DL-Art-School/codes/utils/convert_model.py

# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
# models which can be trained piecemeal.
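#
# For the network weights alone, PyTorch can already do a partial load (a minimal
# sketch, with hypothetical checkpoint/model names):
#
#   bigger_model.load_state_dict(torch.load("smaller_checkpoint.pth"), strict=False)
#
# What strict=False cannot do is migrate optimizer state or other nested training
# bookkeeping, which is what the helpers below handle.
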
import options.options as option
from models import create_model
import torch
import os

def get_model_for_opt_file(filename):
    # Parse an options YAML file and construct the model it describes.
    opt = option.parse(filename, is_train=True)
    opt = option.dict_to_nonedict(opt)
    model = create_model(opt)
    return model, opt

def copy_state_dict_list(l_from, l_to):
    # Element-wise recursive copy of a list into a (possibly longer) target list.
    for i, v in enumerate(l_from):
        if isinstance(v, list):
            copy_state_dict_list(v, l_to[i])
        elif isinstance(v, dict):
            copy_state_dict(v, l_to[i])
        else:
            l_to[i] = v
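
# For example (hypothetical values, not part of the tool):
#   copy_state_dict_list([1, [2, 3]], [0, [0, 0, 4]])
# leaves the target as [1, [2, 3, 4]]: trailing entries that exist only in the
# target are preserved, which is how a freshly added layer keeps its initialization.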

def copy_state_dict(dict_from, dict_to):
    # Recursively merge dict_from into dict_to. Keys that exist only in dict_to are
    # left untouched; every key in dict_from must already exist in dict_to.
    for k in dict_from.keys():
        if k == 'optimizers':
            # Optimizer state gets special treatment: for each param group the source
            # knows about, drop the target's per-parameter state for that group, then
            # take the source's param group and merge in the source's state.
            for j in range(len(dict_from[k][0]['param_groups'])):
                for p in dict_to[k][0]['param_groups'][j]['params']:
                    del dict_to[k][0]['state'][p]
                dict_to[k][0]['param_groups'][j] = dict_from[k][0]['param_groups'][j]
            dict_to[k][0]['state'].update(dict_from[k][0]['state'])
            print(len(dict_from[k][0].keys()), dict_from[k][0].keys())
            print(len(dict_to[k][0].keys()), dict_to[k][0].keys())
        assert k in dict_to.keys()
        if isinstance(dict_from[k], dict):
            copy_state_dict(dict_from[k], dict_to[k])
        elif isinstance(dict_from[k], list):
            copy_state_dict_list(dict_from[k], dict_to[k])
        else:
            dict_to[k] = dict_from[k]
    return dict_to
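
# A minimal, self-contained illustration of the merge semantics above (hypothetical
# values; this helper is never invoked by the tool itself):
def _demo_copy_state_dict():
    src = {'w': 1, 'groups': [10, 20]}
    dst = {'w': 0, 'groups': [0, 0, 30], 'new_layer': 99}
    merged = copy_state_dict(src, dst)
    # Shared keys are overwritten from the source; 'new_layer' and the trailing list
    # entry survive untouched.
    assert merged == {'w': 1, 'groups': [10, 20, 30], 'new_layer': 99}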

if __name__ == "__main__":
    os.chdir("..")
    model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
    model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")
    '''
    model_to.netG.module.update_for_step(1000000000000)
    l = torch.nn.MSELoss()
    o, _ = model_to.netG(torch.randn(1, 3, 64, 64))
    l(o, torch.randn_like(o)).backward()
    model_to.optimizer_G.step()
    o = model_to.netD(torch.randn(1, 3, 128, 128))
    l(o, torch.randn_like(o)).backward()
    model_to.optimizer_D.step()
    '''
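    # (When enabled, the block above runs one dummy forward/backward/step on the target
    # model. PyTorch optimizers create per-parameter state lazily on the first step(),
    # so this ensures the new layer's parameters have entries in the optimizer state
    # before the conversion below. The input shapes are assumptions about these nets.)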
    # Merge the old weights into the new, larger networks and save them.
    torch.save(copy_state_dict(model_from.netG.state_dict(), model_to.netG.state_dict()), "converted_g.pth")
    torch.save(copy_state_dict(model_from.netD.state_dict(), model_to.netD.state_dict()), "converted_d.pth")
    # Also convert the training state. The new model's final param group (the added
    # layer) is appended to the old optimizer state so training can resume.
    resume_state_from = torch.load(opt_from['path']['resume_state'])
    resume_state_to = model_to.save_training_state(0, 0, return_state=True)
    resume_state_from['optimizers'][0]['param_groups'].append(resume_state_to['optimizers'][0]['param_groups'][-1])
    torch.save(resume_state_from, "converted_state.pth")
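
    # The three files written above can then be referenced from the target options file
    # to resume training the enlarged model. A sketch of the relevant 'path' section
    # (the pretrain_model_G / pretrain_model_D key names are assumed from the usual
    # layout of these YAML configs; 'resume_state' matches the key read above):
    #
    #   path:
    #     pretrain_model_G: converted_g.pth
    #     pretrain_model_D: converted_d.pth
    #     resume_state: converted_state.pth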