Add convert_model.py and a hacky way to add extra layers to a model

James Betker 2020-07-22 11:39:45 -06:00
parent 7f7e17e291
commit e3adafbeac
3 changed files with 108 additions and 0 deletions


@@ -182,6 +182,9 @@ class SRGANModel(BaseModel):
        self.load()  # load G and D if needed
        self.load_random_corruptor()

        # Setting this to False triggers SRGAN to call the model's update_model() function on the first iteration.
        self.updated = True

    def feed_data(self, data, need_GT=True):
        _profile = True
        if _profile:
@@ -203,6 +206,10 @@ class SRGANModel(BaseModel):
            self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
            self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]

        if not self.updated:
            self.netG.module.update_model(self.optimizer_G, self.schedulers[0])
            self.updated = True

    def optimize_parameters(self, step):
        _profile = False
        if _profile:
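
The new flag is just a one-shot latch: something outside this hunk is expected to set it to False when the architecture has been grown, and the first feed_data() call then performs the update exactly once, after the optimizer and schedulers already exist. A minimal sketch of that latch, with hypothetical names everywhere except updated and update_model:

class OneShotUpdater:
    # Hypothetical stand-in for SRGANModel; only the updated/update_model handshake mirrors the commit.
    def __init__(self, net, optimizer, scheduler, needs_update=False):
        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler
        # False means update_model() must run once before the first batch is consumed.
        self.updated = not needs_update

    def feed_data(self, data):
        if not self.updated:
            self.net.update_model(self.optimizer, self.scheduler)
            self.updated = True
        self.data = data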


@@ -75,6 +75,32 @@ class GrowingSRGBase(nn.Module):
            param_groups.append({'params': sw_param_group})
        return param_groups

    # This is a hacky way of modifying the underlying model while training. Since changing the model means changing
    # the optimizer and the scheduler as well, both are passed in. For ProgressiveSrg, this function adds one more
    # switch to the end of the chain with depth=3, with its onlining step appended at the end of the function.
    def update_model(self, opt, sched):
        multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
                                        3, self.switch_processing_layers, self.transformation_counts)
        pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
                                            bias=False, weight_init_factor=.1)
        transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
                                         self.transformation_filters, kernel_size=3, depth=self.trans_layers,
                                         weight_init_factor=.1)
        new_sw = srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
                                                pre_transform_block=pretransform_fn,
                                                transform_block=transform_fn,
                                                transform_count=self.transformation_counts, init_temp=self.init_temperature,
                                                add_scalable_noise_to_transforms=self.add_noise_to_transform,
                                                attention_norm=False).to('cuda')
        self.progressive_switches.append(new_sw)
        new_sw_param_group = []
        for k, v in new_sw.named_parameters():
            if v.requires_grad:
                new_sw_param_group.append(v)
        opt.add_param_group({'params': new_sw_param_group})
        self.progressive_schedule.append(150000)
        sched.group_starts.append(150000)

    def get_progressive_starts(self):
        # The base param group starts at step 0; the rest are defined via progressive_switches.
        return [0] + self.progressive_schedule
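
Stripped of the repo-specific switch construction, update_model() follows a small, general recipe: build the new module, collect its trainable parameters into a fresh param group, hand that group to the live optimizer via add_param_group(), and record the global step at which the group comes online (hard-coded to 150000 here, mirrored into the scheduler's group_starts list). A hedged sketch of just that bookkeeping, with a plain Conv2d standing in for the new switch and a plain list standing in for the custom scheduler:

import torch
import torch.nn as nn

def grow_model(switches, optimizer, group_starts, make_block, online_step):
    # Append a new block, register its trainable params as a new optimizer param group,
    # and record the step at which that group should come online.
    new_block = make_block()
    switches.append(new_block)
    new_params = [p for p in new_block.parameters() if p.requires_grad]
    optimizer.add_param_group({'params': new_params})
    group_starts.append(online_step)
    return new_block

switches = nn.ModuleList([nn.Conv2d(3, 3, 3, padding=1)])
opt = torch.optim.Adam(switches.parameters(), lr=1e-4)
starts = [0]
grow_model(switches, opt, starts, lambda: nn.Conv2d(3, 3, 3, padding=1), online_step=150000)
assert len(opt.param_groups) == 2 and starts == [0, 150000]

Passing the optimizer and scheduler in, rather than rebuilding them, is what lets the Adam state for the existing param groups survive the growth.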


@@ -0,0 +1,75 @@
# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
# models which can be trained piecemeal.
import options.options as option
from models import create_model
import torch
import os


def get_model_for_opt_file(filename):
    opt = option.parse(filename, is_train=True)
    opt = option.dict_to_nonedict(opt)
    model = create_model(opt)
    return model, opt

def copy_state_dict_list(l_from, l_to):
    for i, v in enumerate(l_from):
        if isinstance(v, list):
            copy_state_dict_list(v, l_to[i])
        elif isinstance(v, dict):
            copy_state_dict(v, l_to[i])
        else:
            l_to[i] = v

def copy_state_dict(dict_from, dict_to):
    for k in dict_from.keys():
        if k == 'optimizers':
            # Replace each param group that exists in the source, clearing the per-parameter
            # state belonging to the replaced group before merging in the source state.
            for j in range(len(dict_from[k][0]['param_groups'])):
                for p in dict_to[k][0]['param_groups'][j]['params']:
                    del dict_to[k][0]['state'][p]
                dict_to[k][0]['param_groups'][j] = dict_from[k][0]['param_groups'][j]
            dict_to[k][0]['state'].update(dict_from[k][0]['state'])
            print(len(dict_from[k][0].keys()), dict_from[k][0].keys())
            print(len(dict_to[k][0].keys()), dict_to[k][0].keys())
        assert k in dict_to.keys()
        if isinstance(dict_from[k], dict):
            copy_state_dict(dict_from[k], dict_to[k])
        elif isinstance(dict_from[k], list):
            copy_state_dict_list(dict_from[k], dict_to[k])
        else:
            dict_to[k] = dict_from[k]
    return dict_to
if __name__ == "__main__":
os.chdir("..")
torch.backends.cudnn.benchmark = True
want_just_images = True
model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")
'''
model_to.netG.module.update_for_step(1000000000000)
l = torch.nn.MSELoss()
o, _ = model_to.netG(torch.randn(1, 3, 64, 64))
l(o, torch.randn_like(o)).backward()
model_to.optimizer_G.step()
o = model_to.netD(torch.randn(1, 3, 128, 128))
l(o, torch.randn_like(o)).backward()
model_to.optimizer_D.step()
'''
torch.save(copy_state_dict(model_from.netG.state_dict(), model_to.netG.state_dict()), "converted_g.pth")
torch.save(copy_state_dict(model_from.netD.state_dict(), model_to.netD.state_dict()), "converted_d.pth")
# Also convert the state.
resume_state_from = torch.load(opt_from['path']['resume_state'])
resume_state_to = model_to.save_training_state(0, 0, return_state=True)
resume_state_from['optimizers'][0]['param_groups'].append(resume_state_to['optimizers'][0]['param_groups'][-1])
torch.save(resume_state_from, "converted_state.pth")
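
copy_state_dict() is easiest to see on a toy checkpoint: it walks the source dict, recursing into nested dicts and lists, overwrites matching leaves in the destination, and leaves destination-only entries (the new layer's tensors, the extra optimizer param group) untouched. A small illustration, assuming copy_state_dict is importable from the script above:

old = {'step': 10, 'weights': {'w': [1.0, 2.0]}}
new = {'step': 0, 'weights': {'w': [0.0, 0.0], 'w_new_layer': 0.5}, 'extra_flag': True}
merged = copy_state_dict(old, new)
assert merged == {'step': 10, 'weights': {'w': [1.0, 2.0], 'w_new_layer': 0.5}, 'extra_flag': True}

The last three lines of the script apply the same idea to the resume state: the old training state is kept as-is and only the newly added optimizer param group is grafted on from the freshly built model, so converted_state.pth matches the expanded optimizer.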