91 lines
3.3 KiB
Python
91 lines
3.3 KiB
Python
|
import numpy as np
|
||
|
import tensorflow as tf
|
||
|
|
||
|
from .unet import UNet
|
||
|
|
||
|
|
||
|
def tf2pytorch(checkpoint_path, num_instrumments):
|
||
|
tf_vars = {}
|
||
|
init_vars = tf.train.list_variables(checkpoint_path)
|
||
|
# print(init_vars)
|
||
|
for name, shape in init_vars:
|
||
|
try:
|
||
|
# print('Loading TF Weight {} with shape {}'.format(name, shape))
|
||
|
data = tf.train.load_variable(checkpoint_path, name)
|
||
|
tf_vars[name] = data
|
||
|
except Exception as e:
|
||
|
print('Load error')
|
||
|
conv_idx = 0
|
||
|
tconv_idx = 0
|
||
|
bn_idx = 0
|
||
|
outputs = []
|
||
|
for i in range(num_instrumments):
|
||
|
output = {}
|
||
|
outputs.append(output)
|
||
|
|
||
|
for j in range(1,7):
|
||
|
if conv_idx == 0:
|
||
|
conv_suffix = ""
|
||
|
else:
|
||
|
conv_suffix = "_" + str(conv_idx)
|
||
|
|
||
|
if bn_idx == 0:
|
||
|
bn_suffix = ""
|
||
|
else:
|
||
|
bn_suffix = "_" + str(bn_idx)
|
||
|
|
||
|
output['down{}_conv.weight'.format(j)] = np.transpose(
|
||
|
tf_vars["conv2d{}/kernel".format(conv_suffix)], (3, 2, 0, 1))
|
||
|
# print('conv dtype: ',output['down{}.0.weight'.format(j)].dtype)
|
||
|
output['down{}_conv.bias'.format(
|
||
|
j)] = tf_vars["conv2d{}/bias".format(conv_suffix)]
|
||
|
|
||
|
output['down{}_act.0.weight'.format(
|
||
|
j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
|
||
|
output['down{}_act.0.bias'.format(
|
||
|
j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
|
||
|
output['down{}_act.0.running_mean'.format(
|
||
|
j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
|
||
|
output['down{}_act.0.running_var'.format(
|
||
|
j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
|
||
|
|
||
|
conv_idx += 1
|
||
|
bn_idx += 1
|
||
|
|
||
|
# up blocks
|
||
|
for j in range(1, 7):
|
||
|
if tconv_idx == 0:
|
||
|
tconv_suffix = ""
|
||
|
else:
|
||
|
tconv_suffix = "_" + str(tconv_idx)
|
||
|
|
||
|
if bn_idx == 0:
|
||
|
bn_suffix = ""
|
||
|
else:
|
||
|
bn_suffix= "_" + str(bn_idx)
|
||
|
|
||
|
output['up{}.0.weight'.format(j)] = np.transpose(
|
||
|
tf_vars["conv2d_transpose{}/kernel".format(tconv_suffix)], (3,2,0, 1))
|
||
|
output['up{}.0.bias'.format(
|
||
|
j)] = tf_vars["conv2d_transpose{}/bias".format(tconv_suffix)]
|
||
|
output['up{}.2.weight'.format(
|
||
|
j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
|
||
|
output['up{}.2.bias'.format(
|
||
|
j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
|
||
|
output['up{}.2.running_mean'.format(
|
||
|
j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
|
||
|
output['up{}.2.running_var'.format(
|
||
|
j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
|
||
|
tconv_idx += 1
|
||
|
bn_idx += 1
|
||
|
|
||
|
if conv_idx == 0:
|
||
|
suffix = ""
|
||
|
else:
|
||
|
suffix = "_" + str(conv_idx)
|
||
|
output['up7.0.weight'] = np.transpose(
|
||
|
tf_vars['conv2d{}/kernel'.format(suffix)], (3, 2, 0, 1))
|
||
|
output['up7.0.bias'] = tf_vars['conv2d{}/bias'.format(suffix)]
|
||
|
conv_idx += 1
|
||
|
|
||
|
return outputs
|