forked from mrq/DL-Art-School
dbf6147504
The logic is that the discriminator may be incapable of providing a truly targeted loss for all image regions, since it has to be too generic (basically the same argument as for the switched generator). So add some switches in and see how it works!
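For context, here is a minimal sketch of what a "switch" means in this setting: several parallel conv branches whose outputs are blended by a learned, spatially varying selector, so different image regions can receive differently specialized processing. This is an illustration of the concept only; the class name, branch count, and shapes are assumptions, not the repository's actual switched module.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwitchedBlockSketch(nn.Module):
    # Illustrative only: n parallel conv "experts" plus a 1x1 conv that
    # predicts per-pixel selection logits over those experts.
    def __init__(self, channels, num_branches=4):
        super().__init__()
        self.branches = nn.ModuleList(
            nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_branches))
        self.selector = nn.Conv2d(channels, num_branches, 1)

    def forward(self, x):
        expert_out = torch.stack([b(x) for b in self.branches], dim=1)  # (b, n, c, h, w)
        weights = F.softmax(self.selector(x), dim=1).unsqueeze(2)       # (b, n, 1, h, w)
        # Each pixel's output is a convex blend of the branch outputs.
        return (expert_out * weights).sum(dim=1)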
74 lines
2.7 KiB
Python
# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
# models which can be trained piecemeal.
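# Rough flow (a summary of __main__ below): instantiate the "from" and "to"
# models from their option files, recursively copy every weight present in both
# state_dicts, then write out merged generator/discriminator checkpoints and a
# converted resume state. Note that all paths below are relative (os.chdir("..")
# plus "../options/..."), so the directory this is launched from matters.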

import options.options as option
from models import create_model
import torch
import os

def get_model_for_opt_file(filename):
    # Parse an options .yml file and construct the (untrained) model it describes.
    opt = option.parse(filename, is_train=True)
    opt = option.dict_to_nonedict(opt)
    model = create_model(opt)
    return model, opt

def copy_state_dict_list(l_from, l_to):
    # Recursively copy a (possibly nested) list of values from l_from into
    # l_to in place, descending into sub-lists and sub-dicts.
    for i, v in enumerate(l_from):
        if isinstance(v, list):
            copy_state_dict_list(v, l_to[i])
        elif isinstance(v, dict):
            copy_state_dict(v, l_to[i])
        else:
            l_to[i] = v

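# Toy illustration of copy_state_dict_list (hypothetical data, not part of the
# tool): copying [1, [2, 3]] into [0, [0, 0, 9]] yields [1, [2, 3, 9]].
# Positions present in the source are overwritten; trailing extras in the
# target (the 9) survive, which is what lets a freshly added layer keep its
# new parameters.
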
def copy_state_dict(dict_from, dict_to):
    # Recursively copy dict_from into dict_to, with special handling for
    # optimizer state dicts.
    for k in dict_from.keys():
        if k == 'optimizers':
            for j in range(len(dict_from[k][0]['param_groups'])):
                # Drop the target's stale per-parameter state for every
                # parameter in this group before overwriting the group.
                for p in dict_to[k][0]['param_groups'][j]['params']:
                    del dict_to[k][0]['state'][p]
                dict_to[k][0]['param_groups'][j] = dict_from[k][0]['param_groups'][j]
            dict_to[k][0]['state'].update(dict_from[k][0]['state'])
            # Debug output to eyeball the merged optimizer dicts.
            print(len(dict_from[k][0].keys()), dict_from[k][0].keys())
            print(len(dict_to[k][0].keys()), dict_to[k][0].keys())
        assert k in dict_to.keys()
        if isinstance(dict_from[k], dict):
            copy_state_dict(dict_from[k], dict_to[k])
        elif isinstance(dict_from[k], list):
            copy_state_dict_list(dict_from[k], dict_to[k])
        else:
            dict_to[k] = dict_from[k]
    return dict_to

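# Toy illustration of the merge (hypothetical values, not from the tool): with
#   src = {'a': [1, {'b': 2}]}
#   dst = {'a': [0, {'b': 0, 'c': 3}]}
# copy_state_dict(src, dst) leaves dst as {'a': [1, {'b': 2, 'c': 3}]}.
# Every key in src must already exist in dst (the assert above), shared keys
# take src's values, and dst-only keys like 'c' (e.g. a newly added layer's
# weights) are untouched.
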
if __name__ == "__main__":
|
|
os.chdir("..")
|
|
model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
|
|
model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")
|
|
|
|
'''
|
|
model_to.netG.module.update_for_step(1000000000000)
|
|
l = torch.nn.MSELoss()
|
|
o, _ = model_to.netG(torch.randn(1, 3, 64, 64))
|
|
l(o, torch.randn_like(o)).backward()
|
|
model_to.optimizer_G.step()
|
|
o = model_to.netD(torch.randn(1, 3, 128, 128))
|
|
l(o, torch.randn_like(o)).backward()
|
|
model_to.optimizer_D.step()
|
|
'''
|
|
|
|
    torch.save(copy_state_dict(model_from.netG.state_dict(), model_to.netG.state_dict()), "converted_g.pth")
    torch.save(copy_state_dict(model_from.netD.state_dict(), model_to.netD.state_dict()), "converted_d.pth")

    # Also convert the training state: load the source's resume state, graft on
    # the param_group the new model adds, and save it.
    resume_state_from = torch.load(opt_from['path']['resume_state'])
    resume_state_to = model_to.save_training_state(0, 0, return_state=True)
    resume_state_from['optimizers'][0]['param_groups'].append(resume_state_to['optimizers'][0]['param_groups'][-1])
    torch.save(resume_state_from, "converted_state.pth")
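
    # The three outputs above (converted_g.pth, converted_d.pth,
    # converted_state.pth) are meant to be fed back into training via the new
    # option file; presumably through its path: pretrain_model_G /
    # pretrain_model_D / resume_state entries (an assumption based on
    # mmsr-style configs; only resume_state is confirmed by this script).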