Add convert_model.py and a hacky way to add extra layers to a model
This commit is contained in:
parent
7f7e17e291
commit
e3adafbeac
|
@ -182,6 +182,9 @@ class SRGANModel(BaseModel):
|
|||
self.load() # load G and D if needed
|
||||
self.load_random_corruptor()
|
||||
|
||||
# Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
|
||||
self.updated = True
|
||||
|
||||
def feed_data(self, data, need_GT=True):
|
||||
_profile = True
|
||||
if _profile:
|
||||
|
@ -203,6 +206,10 @@ class SRGANModel(BaseModel):
|
|||
self.var_ref = [t.to(self.device) for t in torch.chunk(input_ref, chunks=self.mega_batch_factor, dim=0)]
|
||||
self.pix = [t.to(self.device) for t in torch.chunk(data['PIX'], chunks=self.mega_batch_factor, dim=0)]
|
||||
|
||||
if not self.updated:
|
||||
self.netG.module.update_model(self.optimizer_G, self.schedulers[0])
|
||||
self.updated = True
|
||||
|
||||
def optimize_parameters(self, step):
|
||||
_profile = False
|
||||
if _profile:
|
||||
|
|
|
@ -75,6 +75,32 @@ class GrowingSRGBase(nn.Module):
|
|||
param_groups.append({'params': sw_param_group})
|
||||
return param_groups
|
||||
|
||||
# This is a hacky way of modifying the underlying model while training. Since changing the model means changing
|
||||
# the optimizer and the scheduler, these things are fed in. For ProgressiveSrg, this function adds an additional
|
||||
# switch to the end of the chain with depth=3 and an online time set at the end fo the function.
|
||||
def update_model(self, opt, sched):
|
||||
multiplx_fn = functools.partial(srg.ConvBasisMultiplexer, self.transformation_filters, self.switch_filters,
|
||||
3, self.switch_processing_layers, self.transformation_counts)
|
||||
pretransform_fn = functools.partial(ConvGnLelu, self.transformation_filters, self.transformation_filters, norm=False,
|
||||
bias=False, weight_init_factor=.1)
|
||||
transform_fn = functools.partial(srg.MultiConvBlock, self.transformation_filters, int(self.transformation_filters * 1.5),
|
||||
self.transformation_filters, kernel_size=3, depth=self.trans_layers,
|
||||
weight_init_factor=.1)
|
||||
new_sw = srg.ConfigurableSwitchComputer(self.transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn,
|
||||
transform_block=transform_fn,
|
||||
transform_count=self.transformation_counts, init_temp=self.init_temperature,
|
||||
add_scalable_noise_to_transforms=self.add_noise_to_transform,
|
||||
attention_norm=False).to('cuda')
|
||||
self.progressive_switches.append(new_sw)
|
||||
new_sw_param_group = []
|
||||
for k, v in new_sw.named_parameters():
|
||||
if v.requires_grad:
|
||||
new_sw_param_group.append(v)
|
||||
opt.add_param_group({'params': new_sw_param_group})
|
||||
self.progressive_schedule.append(150000)
|
||||
sched.group_starts.append(150000)
|
||||
|
||||
def get_progressive_starts(self):
|
||||
# The base param group starts at step 0, the rest are defined via progressive_switches.
|
||||
return [0] + self.progressive_schedule
|
||||
|
|
75
codes/utils/convert_model.py
Normal file
75
codes/utils/convert_model.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
# Tool that can be used to add a new layer into an existing model save file. Primarily useful for "progressive"
|
||||
# models which can be trained piecemeal.
|
||||
|
||||
import options.options as option
|
||||
from models import create_model
|
||||
import torch
|
||||
import os
|
||||
|
||||
def get_model_for_opt_file(filename):
|
||||
opt = option.parse(filename, is_train=True)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
model = create_model(opt)
|
||||
return model, opt
|
||||
|
||||
def copy_state_dict_list(l_from, l_to):
|
||||
for i, v in enumerate(l_from):
|
||||
if isinstance(v, list):
|
||||
copy_state_dict_list(v, l_to[i])
|
||||
elif isinstance(v, dict):
|
||||
copy_state_dict(v, l_to[i])
|
||||
else:
|
||||
l_to[i] = v
|
||||
|
||||
def copy_state_dict(dict_from, dict_to):
|
||||
for k in dict_from.keys():
|
||||
if k == 'optimizers':
|
||||
for j in range(len(dict_from[k][0]['param_groups'])):
|
||||
for p in dict_to[k][0]['param_groups'][j]['params']:
|
||||
del dict_to[k][0]['state']
|
||||
dict_to[k][0]['param_groups'][j] = dict_from[k][0]['param_groups'][j]
|
||||
dict_to[k][0]['state'].update(dict_from[k][0]['state'])
|
||||
print(len(dict_from[k][0].keys()), dict_from[k][0].keys())
|
||||
print(len(dict_to[k][0].keys()), dict_to[k][0].keys())
|
||||
assert k in dict_to.keys()
|
||||
if isinstance(dict_from[k], dict):
|
||||
copy_state_dict(dict_from[k], dict_to[k])
|
||||
elif isinstance(dict_from[k], list):
|
||||
copy_state_dict_list(dict_from[k], dict_to[k])
|
||||
else:
|
||||
dict_to[k] = dict_from[k]
|
||||
return dict_to
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.chdir("..")
|
||||
torch.backends.cudnn.benchmark = True
|
||||
want_just_images = True
|
||||
model_from, opt_from = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2.yml")
|
||||
model_to, _ = get_model_for_opt_file("../options/train_imgset_pixgan_progressive_srg2_.yml")
|
||||
|
||||
'''
|
||||
model_to.netG.module.update_for_step(1000000000000)
|
||||
l = torch.nn.MSELoss()
|
||||
o, _ = model_to.netG(torch.randn(1, 3, 64, 64))
|
||||
l(o, torch.randn_like(o)).backward()
|
||||
model_to.optimizer_G.step()
|
||||
o = model_to.netD(torch.randn(1, 3, 128, 128))
|
||||
l(o, torch.randn_like(o)).backward()
|
||||
model_to.optimizer_D.step()
|
||||
'''
|
||||
|
||||
torch.save(copy_state_dict(model_from.netG.state_dict(), model_to.netG.state_dict()), "converted_g.pth")
|
||||
torch.save(copy_state_dict(model_from.netD.state_dict(), model_to.netD.state_dict()), "converted_d.pth")
|
||||
|
||||
# Also convert the state.
|
||||
resume_state_from = torch.load(opt_from['path']['resume_state'])
|
||||
resume_state_to = model_to.save_training_state(0, 0, return_state=True)
|
||||
resume_state_from['optimizers'][0]['param_groups'].append(resume_state_to['optimizers'][0]['param_groups'][-1])
|
||||
torch.save(resume_state_from, "converted_state.pth")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user