Supporting infrastructure in ExtensibleTrainer to train spsr4

Need to be able to train 2 nets in one step: the backbone will be entirely separate
with its own optimizer (for an extremely low LR).

This functionality was already present, just not implemented correctly.
This commit is contained in:
James Betker 2020-09-11 22:57:06 -06:00
parent 4e44bca611
commit fb595e72a4
3 changed files with 45 additions and 22 deletions

View File

@ -50,6 +50,8 @@ class ImageGeneratorInjector(Injector):
results = gen(state[self.input])
new_state = {}
if isinstance(self.output, list):
# Only dereference tuples or lists, not tensors.
assert isinstance(results, list) or isinstance(results, tuple)
for i, k in enumerate(self.output):
new_state[k] = results[i]
else:

View File

@ -36,6 +36,7 @@ class ConfigurableLoss(nn.Module):
self.env = env
self.metrics = []
# net is either a scalar network being trained or a list of networks being trained, depending on the configuration.
def forward(self, net, state):
raise NotImplementedError
@ -58,7 +59,7 @@ class PixLoss(ConfigurableLoss):
self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
def forward(self, net, state):
def forward(self, _, state):
return self.criterion(state[self.opt['fake']], state[self.opt['real']])
@ -72,7 +73,7 @@ class FeatureLoss(ConfigurableLoss):
if not env['opt']['dist']:
self.netF = torch.nn.parallel.DataParallel(self.netF)
def forward(self, net, state):
def forward(self, _, state):
with torch.no_grad():
logits_real = self.netF(state[self.opt['real']])
logits_fake = self.netF(state[self.opt['fake']])
@ -94,7 +95,7 @@ class InterpretedFeatureLoss(ConfigurableLoss):
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
def forward(self, net, state):
def forward(self, _, state):
logits_real = self.netF_real(state[self.opt['real']])
logits_fake = self.netF_gen(state[self.opt['fake']])
return self.criterion(logits_fake, logits_real)
@ -106,7 +107,7 @@ class GeneratorGanLoss(ConfigurableLoss):
self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
def forward(self, net, state):
def forward(self, _, state):
netD = self.env['discriminators'][self.opt['discriminator']]
fake = extract_params_from_state(self.opt['fake'], state)
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:

View File

@ -37,28 +37,45 @@ class ConfigurableStep(Module):
self.weights[loss_name] = loss['weight']
self.losses = OrderedDict(losses)
def get_network_for_name(self, name):
return self.env['generators'][name] if name in self.env['generators'].keys() \
else self.env['discriminators'][name]
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
# This default implementation defines a single optimizer for all Generator parameters.
# Must be called after networks are initialized and wrapped.
def define_optimizers(self):
self.training_net = self.env['generators'][self.step_opt['training']] \
if self.step_opt['training'] in self.env['generators'].keys() \
else self.env['discriminators'][self.step_opt['training']]
training = self.step_opt['training']
if isinstance(training, list):
self.training_net = [self.get_network_for_name(t) for t in training]
opt_configs = [self.step_opt['optimizer_params'][t] for t in training]
nets = self.training_net
else:
self.training_net = self.get_network_for_name(training)
# When only training one network, optimizer params can just embedded in the step params.
if 'optimizer_params' not in self.step_opt.keys():
opt_configs = [self.step_opt]
else:
opt_configs = [self.step_opt['optimizer_params']]
nets = [self.training_net]
self.optimizers = []
for net, opt_config in zip(nets, opt_configs):
optim_params = []
for k, v in self.training_net.named_parameters(): # can optimize for a part of the model
for k, v in net.named_parameters(): # can optimize for a part of the model
if v.requires_grad:
optim_params.append(v)
else:
if self.env['rank'] <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
weight_decay=self.step_opt['weight_decay'],
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
weight_decay=opt_config['weight_decay'],
betas=(opt_config['beta1'], opt_config['beta2']))
elif self.step_opt['optimizer'] == 'novograd':
opt = NovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
self.optimizers = [opt]
opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
betas=(opt_config['beta1'], opt_config['beta2']))
self.optimizers.append(opt)
# Returns all optimizers used in this step.
def get_optimizers(self):
@ -72,6 +89,9 @@ class ConfigurableStep(Module):
# Returns the names of the networks this step will train. Other networks will be frozen.
def get_networks_trained(self):
if isinstance(self.step_opt['training'], list):
return self.step_opt['training']
else:
return [self.step_opt['training']]
# Performs all forward and backward passes for this step given an input state. All input states are lists of