From e3adafbeaceab43858e93f7b21e22ae48dae8549 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 22 Jul 2020 11:39:45 -0600
Subject: [PATCH] Add convert_model.py and a hacky way to add extra layers to
 a model

---
 codes/models/SRGAN_model.py               |  7 +++
 codes/models/archs/ProgressiveSrg_arch.py | 26 ++++++++
 codes/utils/convert_model.py              | 75 +++++++++++++++++++++++
 3 files changed, 108 insertions(+)
 create mode 100644 codes/utils/convert_model.py

diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 697c5575..60baf25b 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -182,6 +182,9 @@ class SRGANModel(BaseModel):
         self.load()  # load G and D if needed
         self.load_random_corruptor()
 
+        # Setting this to False triggers SRGAN to call the model's update_model() function on the first iteration.
+        self.updated = True
+
     def feed_data(self, data, need_GT=True):
         _profile = True
         if _profile:
@@ -203,6 +206,10 @@ class SRGANModel(BaseModel):
             self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
             self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
 
+        if not self.updated:
+            self.netG.module.update_model(self.optimizer_G, self.schedulers[0])
+            self.updated = True
+
     def optimize_parameters(self, step):
         _profile = False
         if _profile:
diff --git a/codes/models/archs/ProgressiveSrg_arch.py b/codes/models/archs/ProgressiveSrg_arch.py
index df09cdce..ee0043f8 100644
--- a/codes/models/archs/ProgressiveSrg_arch.py
+++ b/codes/models/archs/ProgressiveSrg_arch.py
@@ -75,6 +75,32 @@ class GrowingSRGBase(nn.Module):
             param_groups.append({'params': sw_param_group})
         return param_groups
 
+    # This is a hacky way of modifying the underlying model while training. Since changing the model means changing
+    # the optimizer and the scheduler as well, both are fed in. For ProgressiveSrg, this function adds an additional
+    # switch to the end of the chain with depth=3 and an online time set at the end of the function.
+    def update_model(self, opt, sched):
+        multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
+                                        3, self.switch_processing_layers, self.transformation_counts)
+        pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
+                                            bias=False, weight_init_factor=.1)
+        transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
+                                         self.transformation_filters, kernel_size=3, depth=self.trans_layers,
+                                         weight_init_factor=.1)
+        new_sw = srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
+                                                pre_transform_block=pretransform_fn,
+                                                transform_block=transform_fn,
+                                                transform_count=self.transformation_counts, init_temp=self.init_temperature,
+                                                add_scalable_noise_to_transforms=self.add_noise_to_transform,
+                                                attention_norm=False).to('cuda')
+        self.progressive_switches.append(new_sw)
+        new_sw_param_group = []
+        for k, v in new_sw.named_parameters():
+            if v.requires_grad:
+                new_sw_param_group.append(v)
+        opt.add_param_group({'params': new_sw_param_group})
+        self.progressive_schedule.append(150000)
+        sched.group_starts.append(150000)
+
     def get_progressive_starts(self):
         # The base param group starts at step 0, the rest are defined via progressive_switches.
         return [0] + self.progressive_schedule
diff --git a/codes/utils/convert_model.py b/codes/utils/convert_model.py
new file mode 100644
index 00000000..6d31aac6
--- /dev/null
+++ b/codes/utils/convert_model.py
@@ -0,0 +1,75 @@
+# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
+# models which can be trained piecemeal.
+
+import options.options as option
+from models import create_model
+import torch
+import os
+
+def get_model_for_opt_file(filename):
+    opt = option.parse(filename, is_train=True)
+    opt = option.dict_to_nonedict(opt)
+    model = create_model(opt)
+    return model, opt
+
+def copy_state_dict_list(l_from, l_to):
+    for i, v in enumerate(l_from):
+        if isinstance(v, list):
+            copy_state_dict_list(v, l_to[i])
+        elif isinstance(v, dict):
+            copy_state_dict(v, l_to[i])
+        else:
+            l_to[i] = v
+
+def copy_state_dict(dict_from, dict_to):
+    for k in dict_from.keys():
+        if k == 'optimizers':
+            for j in range(len(dict_from[k][0]['param_groups'])):
+                for p in dict_to[k][0]['param_groups'][j]['params']:
+                    del dict_to[k][0]['state'][p]  # drop stale per-param state before the merge below
+                dict_to[k][0]['param_groups'][j] = dict_from[k][0]['param_groups'][j]
+            dict_to[k][0]['state'].update(dict_from[k][0]['state'])
+            print(len(dict_from[k][0].keys()), dict_from[k][0].keys())
+            print(len(dict_to[k][0].keys()), dict_to[k][0].keys())
+            continue  # the optimizer state was fully handled above
+        assert k in dict_to.keys()
+        if isinstance(dict_from[k], dict):
+            copy_state_dict(dict_from[k], dict_to[k])
+        elif isinstance(dict_from[k], list):
+            copy_state_dict_list(dict_from[k], dict_to[k])
+        else:
+            dict_to[k] = dict_from[k]
+    return dict_to
+
+if __name__ == "__main__":
+    os.chdir("..")
+    torch.backends.cudnn.benchmark = True
+    want_just_images = True
+    model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
+    model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")
+
+    '''
+    model_to.netG.module.update_for_step(1000000000000)
+    l = torch.nn.MSELoss()
+    o, _ = model_to.netG(torch.randn(1, 3, 64, 64))
+    l(o, torch.randn_like(o)).backward()
+    model_to.optimizer_G.step()
+    o = model_to.netD(torch.randn(1, 3, 128, 128))
+    l(o, torch.randn_like(o)).backward()
+    model_to.optimizer_D.step()
+    '''
+
+    torch.save(copy_state_dict(model_from.netG.state_dict(), model_to.netG.state_dict()), "converted_g.pth")
+    torch.save(copy_state_dict(model_from.netD.state_dict(), model_to.netD.state_dict()), "converted_d.pth")
+
+    # Also convert the training state.
+    resume_state_from = torch.load(opt_from['path']['resume_state'])
+    resume_state_to = model_to.save_training_state(0, 0, return_state=True)
+    resume_state_from['optimizers'][0]['param_groups'].append(resume_state_to['optimizers'][0]['param_groups'][-1])
+    torch.save(resume_state_from, "converted_state.pth")
+
+
+
+
+
+
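
Reviewer note, outside the patch proper: update_model() leans on the fact that a live PyTorch optimizer can be grown in place with optimizer.add_param_group(), which is what lets the newly appended switch start training without rebuilding the optimizer; the last step of convert_model.py patches the saved resume state to match by appending the corresponding param group. A minimal sketch of that pattern, assuming a plain Adam optimizer; the toy layers and learning rate below are illustrative, not code from this repo.

    # Sketch (illustrative, not repo code) of growing a live optimizer,
    # the same pattern update_model() uses for the new switch's parameters.
    import torch
    import torch.nn as nn

    layers = nn.ModuleList([nn.Linear(8, 8)])
    opt = torch.optim.Adam(layers.parameters(), lr=1e-4)  # assumed hyperparameters

    # Grow the model mid-training...
    new_layer = nn.Linear(8, 8)
    layers.append(new_layer)

    # ...then register only the new trainable parameters, as update_model() does.
    new_params = [p for p in new_layer.parameters() if p.requires_grad]
    opt.add_param_group({'params': new_params})

    assert len(opt.param_groups) == 2  # original group plus the newly added one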
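
Similarly, the recursion in copy_state_dict / copy_state_dict_list (outside the special-cased 'optimizers' key) overwrites matching destination entries with source values while leaving destination-only keys, such as the weights of a newly added switch, at their freshly initialized values. A toy illustration of that merge behavior, an assumed example rather than repo code:

    # Toy version of the recursive copy in copy_state_dict, minus the
    # optimizer special case; illustrative only.
    def merge(src, dst):
        for k, v in src.items():
            if isinstance(v, dict):
                merge(v, dst[k])
            else:
                dst[k] = v
        return dst

    src = {'w1': 1.0, 'block': {'w2': 2.0}}
    dst = {'w1': 0.0, 'block': {'w2': 0.0, 'w3_new': 9.0}}  # 'w3_new' exists only in dst
    assert merge(src, dst) == {'w1': 1.0, 'block': {'w2': 2.0, 'w3_new': 9.0}}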