forked from mrq/DL-Art-School
Supporting infrastructure in ExtensibleTrainer to train spsr4
Need to be able to train 2 nets in one step: the backbone will be entirely separate with its own optimizer (for an extremely low LR). This functionality was already present, just not implemented correctly.
This commit is contained in:
parent
4e44bca611
commit
fb595e72a4
|
@ -50,6 +50,8 @@ class ImageGeneratorInjector(Injector):
|
|||
results = gen(state[self.input])
|
||||
new_state = {}
|
||||
if isinstance(self.output, list):
|
||||
# Only dereference tuples or lists, not tensors.
|
||||
assert isinstance(results, list) or isinstance(results, tuple)
|
||||
for i, k in enumerate(self.output):
|
||||
new_state[k] = results[i]
|
||||
else:
|
||||
|
|
|
@ -36,6 +36,7 @@ class ConfigurableLoss(nn.Module):
|
|||
self.env = env
|
||||
self.metrics = []
|
||||
|
||||
# net is either a scalar network being trained or a list of networks being trained, depending on the configuration.
|
||||
def forward(self, net, state):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -58,7 +59,7 @@ class PixLoss(ConfigurableLoss):
|
|||
self.opt = opt
|
||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||
|
||||
def forward(self, net, state):
|
||||
def forward(self, _, state):
|
||||
return self.criterion(state[self.opt['fake']], state[self.opt['real']])
|
||||
|
||||
|
||||
|
@ -72,7 +73,7 @@ class FeatureLoss(ConfigurableLoss):
|
|||
if not env['opt']['dist']:
|
||||
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
||||
|
||||
def forward(self, net, state):
|
||||
def forward(self, _, state):
|
||||
with torch.no_grad():
|
||||
logits_real = self.netF(state[self.opt['real']])
|
||||
logits_fake = self.netF(state[self.opt['fake']])
|
||||
|
@ -94,7 +95,7 @@ class InterpretedFeatureLoss(ConfigurableLoss):
|
|||
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
||||
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
||||
|
||||
def forward(self, net, state):
|
||||
def forward(self, _, state):
|
||||
logits_real = self.netF_real(state[self.opt['real']])
|
||||
logits_fake = self.netF_gen(state[self.opt['fake']])
|
||||
return self.criterion(logits_fake, logits_real)
|
||||
|
@ -106,7 +107,7 @@ class GeneratorGanLoss(ConfigurableLoss):
|
|||
self.opt = opt
|
||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||
|
||||
def forward(self, net, state):
|
||||
def forward(self, _, state):
|
||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||
fake = extract_params_from_state(self.opt['fake'], state)
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
|
||||
|
|
|
@ -37,28 +37,45 @@ class ConfigurableStep(Module):
|
|||
self.weights[loss_name] = loss['weight']
|
||||
self.losses = OrderedDict(losses)
|
||||
|
||||
def get_network_for_name(self, name):
|
||||
return self.env['generators'][name] if name in self.env['generators'].keys() \
|
||||
else self.env['discriminators'][name]
|
||||
|
||||
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
||||
# This default implementation defines a single optimizer for all Generator parameters.
|
||||
# Must be called after networks are initialized and wrapped.
|
||||
def define_optimizers(self):
|
||||
self.training_net = self.env['generators'][self.step_opt['training']] \
|
||||
if self.step_opt['training'] in self.env['generators'].keys() \
|
||||
else self.env['discriminators'][self.step_opt['training']]
|
||||
training = self.step_opt['training']
|
||||
if isinstance(training, list):
|
||||
self.training_net = [self.get_network_for_name(t) for t in training]
|
||||
opt_configs = [self.step_opt['optimizer_params'][t] for t in training]
|
||||
nets = self.training_net
|
||||
else:
|
||||
self.training_net = self.get_network_for_name(training)
|
||||
# When only training one network, optimizer params can just embedded in the step params.
|
||||
if 'optimizer_params' not in self.step_opt.keys():
|
||||
opt_configs = [self.step_opt]
|
||||
else:
|
||||
opt_configs = [self.step_opt['optimizer_params']]
|
||||
nets = [self.training_net]
|
||||
self.optimizers = []
|
||||
for net, opt_config in zip(nets, opt_configs):
|
||||
optim_params = []
|
||||
for k, v in self.training_net.named_parameters(): # can optimize for a part of the model
|
||||
for k, v in net.named_parameters(): # can optimize for a part of the model
|
||||
if v.requires_grad:
|
||||
optim_params.append(v)
|
||||
else:
|
||||
if self.env['rank'] <= 0:
|
||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||
|
||||
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
|
||||
opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
|
||||
weight_decay=self.step_opt['weight_decay'],
|
||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
||||
opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
|
||||
weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
elif self.step_opt['optimizer'] == 'novograd':
|
||||
opt = NovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
|
||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
||||
self.optimizers = [opt]
|
||||
opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
|
||||
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||
self.optimizers.append(opt)
|
||||
|
||||
# Returns all optimizers used in this step.
|
||||
def get_optimizers(self):
|
||||
|
@ -72,6 +89,9 @@ class ConfigurableStep(Module):
|
|||
|
||||
# Returns the names of the networks this step will train. Other networks will be frozen.
|
||||
def get_networks_trained(self):
|
||||
if isinstance(self.step_opt['training'], list):
|
||||
return self.step_opt['training']
|
||||
else:
|
||||
return [self.step_opt['training']]
|
||||
|
||||
# Performs all forward and backward passes for this step given an input state. All input states are lists of
|
||||
|
|
Loading…
Reference in New Issue
Block a user