DL-Art-School/codes/models/spleeter/util.py
2021-09-09 23:14:32 -06:00

91 lines
3.3 KiB
Python

import numpy as np
import tensorflow as tf
from .unet import UNet
def _tf_suffix(index):
    """Return the layer-name suffix used for the ``index``-th layer of a kind.

    The first layer carries no suffix (``conv2d``); every later one is
    suffixed with its index (``conv2d_1``, ``conv2d_2``, ...).
    """
    if index == 0:
        return ""
    return "_" + str(index)


def _copy_batch_norm(tf_vars, bn_idx, output, prefix):
    """Copy the ``bn_idx``-th TF batch-norm variables into ``output``.

    The four entries are stored under ``prefix`` + ``.weight`` / ``.bias`` /
    ``.running_mean`` / ``.running_var``. Raises ``KeyError`` if any of the
    expected TensorFlow variables is missing from ``tf_vars``.
    """
    scope = "batch_normalization" + _tf_suffix(bn_idx)
    output[prefix + '.weight'] = tf_vars[scope + "/gamma"]
    output[prefix + '.bias'] = tf_vars[scope + "/beta"]
    output[prefix + '.running_mean'] = tf_vars[scope + "/moving_mean"]
    output[prefix + '.running_var'] = tf_vars[scope + "/moving_variance"]


def tf2pytorch(checkpoint_path, num_instrumments):
    """Convert a TensorFlow checkpoint into per-instrument weight dicts.

    Every variable listed in the checkpoint is loaded, then regrouped into one
    dict per instrument. Each dict holds:

    * ``down{j}_conv.*`` and ``down{j}_act.0.*`` for ``j`` in 1..6
      (a ``conv2d`` layer followed by a ``batch_normalization`` layer),
    * ``up{j}.0.*`` and ``up{j}.2.*`` for ``j`` in 1..6
      (a ``conv2d_transpose`` layer followed by a ``batch_normalization``
      layer),
    * ``up7.0.weight`` / ``up7.0.bias`` (one more ``conv2d`` layer).

    The layer counters are NOT reset between instruments, so each instrument
    consumes the next 7 ``conv2d``, 6 ``conv2d_transpose`` and 12
    ``batch_normalization`` variables in sequence.

    Args:
        checkpoint_path: Checkpoint location handed to
            ``tf.train.list_variables`` / ``tf.train.load_variable``.
        num_instrumments: Number of instruments (one dict is built for each).

    Returns:
        A list of ``num_instrumments`` dicts mapping parameter names to the
        loaded arrays. The key names presumably mirror the ``state_dict`` of
        the ``UNet`` module imported by this file -- verify against ``.unet``.

    Raises:
        KeyError: If a variable required by the mapping was not loaded
            (either absent from the checkpoint or its load failed).
    """
    tf_vars = {}
    for name, _shape in tf.train.list_variables(checkpoint_path):
        try:
            tf_vars[name] = tf.train.load_variable(checkpoint_path, name)
        except Exception as e:
            # Loading stays best-effort, but name the variable and the cause
            # so a later KeyError can be traced back to the failed load.
            print('Load error for variable {!r}: {}'.format(name, e))

    # TF keeps conv kernels as (H, W, in, out) and transposed-conv kernels as
    # (H, W, out, in); PyTorch expects (out, in, H, W) and (in, out, H, W)
    # respectively, so the same permutation serves both layer types.
    kernel_perm = (3, 2, 0, 1)

    conv_idx = 0
    tconv_idx = 0
    bn_idx = 0
    outputs = []
    for _ in range(num_instrumments):
        output = {}
        outputs.append(output)

        # down blocks
        for j in range(1, 7):
            conv_scope = "conv2d" + _tf_suffix(conv_idx)
            output['down{}_conv.weight'.format(j)] = np.transpose(
                tf_vars[conv_scope + "/kernel"], kernel_perm)
            output['down{}_conv.bias'.format(j)] = tf_vars[conv_scope + "/bias"]
            _copy_batch_norm(tf_vars, bn_idx, output,
                             'down{}_act.0'.format(j))
            conv_idx += 1
            bn_idx += 1

        # up blocks
        for j in range(1, 7):
            tconv_scope = "conv2d_transpose" + _tf_suffix(tconv_idx)
            output['up{}.0.weight'.format(j)] = np.transpose(
                tf_vars[tconv_scope + "/kernel"], kernel_perm)
            output['up{}.0.bias'.format(j)] = tf_vars[tconv_scope + "/bias"]
            _copy_batch_norm(tf_vars, bn_idx, output, 'up{}.2'.format(j))
            tconv_idx += 1
            bn_idx += 1

        # The last layer of each instrument is a plain conv2d, so it shares
        # the conv2d counter with the down blocks.
        final_scope = "conv2d" + _tf_suffix(conv_idx)
        output['up7.0.weight'] = np.transpose(
            tf_vars[final_scope + "/kernel"], kernel_perm)
        output['up7.0.bias'] = tf_vars[final_scope + "/bias"]
        conv_idx += 1

    return outputs