forked from mrq/DL-Art-School
74 lines
2.7 KiB
Python
74 lines
2.7 KiB
Python
|
# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
|
||
|
# models which can be trained piecemeal.
|
||
|
|
||
|
import options.options as option
|
||
|
from models import create_model
|
||
|
import torch
|
||
|
import os
|
||
|
|
||
|
def get_model_for_opt_file(filename):
    """Build a model from a training option file.

    Args:
        filename: path to an options file, parsed with
            ``option.parse(filename, is_train=True)``.

    Returns:
        A ``(model, opt)`` tuple: the model returned by ``create_model`` and
        the parsed options after conversion by ``option.dict_to_nonedict``.
    """
    parsed_opt = option.dict_to_nonedict(option.parse(filename, is_train=True))
    return create_model(parsed_opt), parsed_opt
|
||
|
|
||
|
def copy_state_dict_list(l_from, l_to):
    """Copy the elements of ``l_from`` into ``l_to`` position by position, in place.

    Nested lists are merged recursively with this function and nested dicts with
    ``copy_state_dict``; any other element overwrites ``l_to`` at the same index.
    Elements of ``l_to`` beyond ``len(l_from)`` are left untouched, and ``l_to``
    must have at least as many entries as ``l_from`` (otherwise indexing fails).

    Returns:
        None -- ``l_to`` is modified in place.
    """
    for position, entry in enumerate(l_from):
        if not isinstance(entry, (list, dict)):
            # Leaf value: plain overwrite.
            l_to[position] = entry
            continue
        # Containers are merged into the existing destination object, not replaced.
        merge = copy_state_dict_list if isinstance(entry, list) else copy_state_dict
        merge(entry, l_to[position])
|
||
|
|
||
|
def _merge_first_optimizer(opt_from, opt_to):
    """Splice the param groups and per-param state of ``opt_from`` into ``opt_to`` in place.

    Both arguments are expected to be optimizer-state mappings holding
    ``'param_groups'`` and ``'state'`` entries. For each param group index present in
    ``opt_from``, the state entries of the params that ``opt_to`` currently lists in
    that group are dropped, and the group is replaced by ``opt_from``'s. Finally
    ``opt_from``'s state is merged into ``opt_to``'s. Groups that only ``opt_to`` has
    are left untouched.
    """
    groups_from = opt_from['param_groups']
    groups_to = opt_to['param_groups']
    state_to = opt_to['state']
    # Index-based on purpose: groups_to is indexed in parallel and assigned into.
    for j in range(len(groups_from)):
        for p in groups_to[j]['params']:
            # Remove only this param's stale entry. pop() with a default because a
            # param the optimizer never stepped has no state entry at all.
            state_to.pop(p, None)
        groups_to[j] = groups_from[j]
    state_to.update(opt_from['state'])


def copy_state_dict(dict_from, dict_to):
    """Recursively overlay the contents of ``dict_from`` onto ``dict_to``, in place.

    Every key of ``dict_from`` must already exist in ``dict_to``. Nested dicts are
    merged recursively, nested lists are merged with ``copy_state_dict_list``, and any
    other value overwrites the destination entry. Keys present only in ``dict_to``
    keep their values. A key named ``'optimizers'`` additionally gets its first
    optimizer's param groups and state spliced in by ``_merge_first_optimizer`` (and
    the key sets of both first optimizers are printed) before the generic merge.

    Args:
        dict_from: source mapping, e.g. an old model's state dict.
        dict_to: destination mapping, modified in place.

    Returns:
        ``dict_to``.

    Raises:
        KeyError: if a key of ``dict_from`` is missing from ``dict_to``.
    """
    for k in dict_from:
        if k == 'optimizers':
            _merge_first_optimizer(dict_from[k][0], dict_to[k][0])
            print(len(dict_from[k][0].keys()), dict_from[k][0].keys())
            print(len(dict_to[k][0].keys()), dict_to[k][0].keys())
        if k not in dict_to:
            # Explicit error rather than assert, so the check survives `python -O`.
            raise KeyError('%r exists in the source state dict but not in the destination' % (k,))
        value = dict_from[k]
        if isinstance(value, dict):
            copy_state_dict(value, dict_to[k])
        elif isinstance(value, list):
            copy_state_dict_list(value, dict_to[k])
        else:
            dict_to[k] = value
    return dict_to
|
||
|
|
||
|
if __name__ == "__main__":
    # Relative paths below (the option files and the converted_*.pth outputs) are
    # resolved after this chdir, i.e. from the parent of the launch directory.
    os.chdir("..")

    source_options = "../options/train_imgset_pixgan_progressive_srg2.yml"
    target_options = "../options/train_imgset_pixgan_progressive_srg2_.yml"
    model_from, opt_from = get_model_for_opt_file(source_options)
    model_to, _ = get_model_for_opt_file(target_options)

    # Disabled snippet, kept for reference only (it was never executed). It pushes a
    # random batch through model_to's generator and discriminator and steps both
    # optimizers. NOTE(review): presumably meant to populate model_to's optimizer
    # state before save_training_state() -- confirm, or delete.
    #
    #   model_to.netG.module.update_for_step(1000000000000)
    #   l = torch.nn.MSELoss()
    #   o, _ = model_to.netG(torch.randn(1, 3, 64, 64))
    #   l(o, torch.randn_like(o)).backward()
    #   model_to.optimizer_G.step()
    #   o = model_to.netD(torch.randn(1, 3, 128, 128))
    #   l(o, torch.randn_like(o)).backward()
    #   model_to.optimizer_D.step()

    # Overlay the old networks' weights onto the new networks' state dicts; entries
    # that only the new networks have keep the new networks' values.
    for net_attr, out_file in (("netG", "converted_g.pth"), ("netD", "converted_d.pth")):
        merged_state = copy_state_dict(getattr(model_from, net_attr).state_dict(),
                                       getattr(model_to, net_attr).state_dict())
        torch.save(merged_state, out_file)

    # Also convert the training state: keep the old checkpoint and append the new
    # model's last param group of optimizer 0 to it.
    resume_state_from = torch.load(opt_from['path']['resume_state'])
    resume_state_to = model_to.save_training_state(0, 0, return_state=True)
    appended_group = resume_state_to['optimizers'][0]['param_groups'][-1]
    resume_state_from['optimizers'][0]['param_groups'].append(appended_group)
    torch.save(resume_state_from, "converted_state.pth")
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|