From fb595e72a4cd9f62990c238e58b5c8fe2cea0ca0 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 11 Sep 2020 22:57:06 -0600
Subject: [PATCH] Supporting infrastructure in ExtensibleTrainer to train spsr4

We need to be able to train two nets in one step: the backbone will be
entirely separate, with its own optimizer (for an extremely low LR).
This functionality was already present, but was not implemented
correctly.
---
 codes/models/steps/injectors.py |  2 ++
 codes/models/steps/losses.py    |  9 +++---
 codes/models/steps/steps.py     | 56 ++++++++++++++++++++++-----------
 3 files changed, 45 insertions(+), 22 deletions(-)

diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index c0b3fb9c..897f854a 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -50,6 +50,8 @@ class ImageGeneratorInjector(Injector):
         results = gen(state[self.input])
         new_state = {}
         if isinstance(self.output, list):
+            # Only dereference tuples or lists, not tensors.
+            assert isinstance(results, (list, tuple))
             for i, k in enumerate(self.output):
                 new_state[k] = results[i]
         else:
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index 5ef1dde2..754dacc2 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -36,6 +36,7 @@ class ConfigurableLoss(nn.Module):
         self.env = env
         self.metrics = []
 
+    # net is either a single network being trained or a list of networks being trained, depending on the configuration.
     def forward(self, net, state):
         raise NotImplementedError
 
@@ -58,7 +59,7 @@ class PixLoss(ConfigurableLoss):
         self.opt = opt
         self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
 
-    def forward(self, net, state):
+    def forward(self, _, state):
         return self.criterion(state[self.opt['fake']], state[self.opt['real']])
 
 
@@ -72,7 +73,7 @@ class FeatureLoss(ConfigurableLoss):
         if not env['opt']['dist']:
             self.netF = torch.nn.parallel.DataParallel(self.netF)
 
-    def forward(self, net, state):
+    def forward(self, _, state):
         with torch.no_grad():
             logits_real = self.netF(state[self.opt['real']])
         logits_fake = self.netF(state[self.opt['fake']])
@@ -94,7 +95,7 @@ class InterpretedFeatureLoss(ConfigurableLoss):
             self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
             self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
 
-    def forward(self, net, state):
+    def forward(self, _, state):
         logits_real = self.netF_real(state[self.opt['real']])
         logits_fake = self.netF_gen(state[self.opt['fake']])
         return self.criterion(logits_fake, logits_real)
@@ -106,7 +107,7 @@ class GeneratorGanLoss(ConfigurableLoss):
         self.opt = opt
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
 
-    def forward(self, net, state):
+    def forward(self, _, state):
         netD = self.env['discriminators'][self.opt['discriminator']]
         fake = extract_params_from_state(self.opt['fake'], state)
         if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index d1136883..b28f6d80 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -37,28 +37,45 @@ class ConfigurableStep(Module):
             self.weights[loss_name] = loss['weight']
         self.losses = OrderedDict(losses)
 
+    def get_network_for_name(self, name):
+        return self.env['generators'][name] if name in self.env['generators'].keys() \
+            else self.env['discriminators'][name]
+
     # Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
     # This default implementation defines a single optimizer for all Generator parameters.
     # Must be called after networks are initialized and wrapped.
     def define_optimizers(self):
-        self.training_net = self.env['generators'][self.step_opt['training']] \
-            if self.step_opt['training'] in self.env['generators'].keys() \
-            else self.env['discriminators'][self.step_opt['training']]
-        optim_params = []
-        for k, v in self.training_net.named_parameters():  # can optimize for a part of the model
-            if v.requires_grad:
-                optim_params.append(v)
+        training = self.step_opt['training']
+        if isinstance(training, list):
+            self.training_net = [self.get_network_for_name(t) for t in training]
+            opt_configs = [self.step_opt['optimizer_params'][t] for t in training]
+            nets = self.training_net
+        else:
+            self.training_net = self.get_network_for_name(training)
+            # When only training one network, the optimizer params can just be embedded in the step params.
+            if 'optimizer_params' not in self.step_opt.keys():
+                opt_configs = [self.step_opt]
             else:
-            if self.env['rank'] <= 0:
-                logger.warning('Params [{:s}] will not optimize.'.format(k))
-        if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
-            opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
-                                   weight_decay=self.step_opt['weight_decay'],
-                                   betas=(self.step_opt['beta1'], self.step_opt['beta2']))
-        elif self.step_opt['optimizer'] == 'novograd':
-            opt = NovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
-                           betas=(self.step_opt['beta1'], self.step_opt['beta2']))
-        self.optimizers = [opt]
+                opt_configs = [self.step_opt['optimizer_params']]
+            nets = [self.training_net]
+        self.optimizers = []
+        for net, opt_config in zip(nets, opt_configs):
+            optim_params = []
+            for k, v in net.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    if self.env['rank'] <= 0:
+                        logger.warning('Params [{:s}] will not optimize.'.format(k))
+
+            if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
+                opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
+                                       weight_decay=opt_config['weight_decay'],
+                                       betas=(opt_config['beta1'], opt_config['beta2']))
+            elif self.step_opt['optimizer'] == 'novograd':
+                opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
+                               betas=(opt_config['beta1'], opt_config['beta2']))
+            self.optimizers.append(opt)
 
     # Returns all optimizers used in this step.
     def get_optimizers(self):
@@ -72,7 +89,10 @@ class ConfigurableStep(Module):
 
     # Returns the names of the networks this step will train. Other networks will be frozen.
     def get_networks_trained(self):
-        return [self.step_opt['training']]
+        if isinstance(self.step_opt['training'], list):
+            return self.step_opt['training']
+        else:
+            return [self.step_opt['training']]
 
     # Performs all forward and backward passes for this step given an input state. All input states are lists of
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
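
Usage sketch (illustrative, not part of the patch): after this change, a
step's 'training' option may name a list of networks, each with its own
entry under 'optimizer_params'. The net names ('generator', 'backbone'),
the LR values, and the toy modules below are hypothetical stand-ins;
ExtensibleTrainer builds the real env from its YAML options. The loop
mirrors what define_optimizers() now does per trained net:

    import torch
    import torch.nn as nn

    # Hypothetical step options: two nets trained in one step, with the
    # backbone given its own optimizer at an extremely low LR.
    step_opt = {
        'training': ['generator', 'backbone'],
        'optimizer_params': {
            'generator': {'lr': 1e-4, 'weight_decay': 0, 'beta1': 0.9, 'beta2': 0.99},
            'backbone':  {'lr': 1e-7, 'weight_decay': 0, 'beta1': 0.9, 'beta2': 0.99},
        },
    }

    # Stand-in for the env dict that get_network_for_name() reads from.
    env = {'generators': {'generator': nn.Linear(8, 8), 'backbone': nn.Linear(8, 8)},
           'discriminators': {}}

    # One optimizer per trained net, each built from that net's own
    # optimizer_params entry, the same shape as the new multi-net branch.
    nets = [env['generators'][name] for name in step_opt['training']]
    opt_configs = [step_opt['optimizer_params'][name] for name in step_opt['training']]
    optimizers = []
    for net, cfg in zip(nets, opt_configs):
        optim_params = [p for p in net.parameters() if p.requires_grad]
        optimizers.append(torch.optim.Adam(optim_params, lr=cfg['lr'],
                                           weight_decay=cfg['weight_decay'],
                                           betas=(cfg['beta1'], cfg['beta2'])))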