forked from mrq/DL-Art-School
Supporting infrastructure in ExtensibleTrainer to train spsr4
Need to be able to train 2 nets in one step: the backbone will be entirely separate with its own optimizer (for an extremely low LR). This functionality was already present, just not implemented correctly.
This commit is contained in:
parent
4e44bca611
commit
fb595e72a4
|
@ -50,6 +50,8 @@ class ImageGeneratorInjector(Injector):
|
||||||
results = gen(state[self.input])
|
results = gen(state[self.input])
|
||||||
new_state = {}
|
new_state = {}
|
||||||
if isinstance(self.output, list):
|
if isinstance(self.output, list):
|
||||||
|
# Only dereference tuples or lists, not tensors.
|
||||||
|
assert isinstance(results, list) or isinstance(results, tuple)
|
||||||
for i, k in enumerate(self.output):
|
for i, k in enumerate(self.output):
|
||||||
new_state[k] = results[i]
|
new_state[k] = results[i]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -36,6 +36,7 @@ class ConfigurableLoss(nn.Module):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.metrics = []
|
self.metrics = []
|
||||||
|
|
||||||
|
# net is either a scalar network being trained or a list of networks being trained, depending on the configuration.
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -58,7 +59,7 @@ class PixLoss(ConfigurableLoss):
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, _, state):
|
||||||
return self.criterion(state[self.opt['fake']], state[self.opt['real']])
|
return self.criterion(state[self.opt['fake']], state[self.opt['real']])
|
||||||
|
|
||||||
|
|
||||||
|
@ -72,7 +73,7 @@ class FeatureLoss(ConfigurableLoss):
|
||||||
if not env['opt']['dist']:
|
if not env['opt']['dist']:
|
||||||
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, _, state):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits_real = self.netF(state[self.opt['real']])
|
logits_real = self.netF(state[self.opt['real']])
|
||||||
logits_fake = self.netF(state[self.opt['fake']])
|
logits_fake = self.netF(state[self.opt['fake']])
|
||||||
|
@ -94,7 +95,7 @@ class InterpretedFeatureLoss(ConfigurableLoss):
|
||||||
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
||||||
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, _, state):
|
||||||
logits_real = self.netF_real(state[self.opt['real']])
|
logits_real = self.netF_real(state[self.opt['real']])
|
||||||
logits_fake = self.netF_gen(state[self.opt['fake']])
|
logits_fake = self.netF_gen(state[self.opt['fake']])
|
||||||
return self.criterion(logits_fake, logits_real)
|
return self.criterion(logits_fake, logits_real)
|
||||||
|
@ -106,7 +107,7 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, _, state):
|
||||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||||
fake = extract_params_from_state(self.opt['fake'], state)
|
fake = extract_params_from_state(self.opt['fake'], state)
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
|
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
|
||||||
|
|
|
@ -37,28 +37,45 @@ class ConfigurableStep(Module):
|
||||||
self.weights[loss_name] = loss['weight']
|
self.weights[loss_name] = loss['weight']
|
||||||
self.losses = OrderedDict(losses)
|
self.losses = OrderedDict(losses)
|
||||||
|
|
||||||
|
def get_network_for_name(self, name):
|
||||||
|
return self.env['generators'][name] if name in self.env['generators'].keys() \
|
||||||
|
else self.env['discriminators'][name]
|
||||||
|
|
||||||
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
# Subclasses should override this to define individual optimizers. They should all go into self.optimizers.
|
||||||
# This default implementation defines a single optimizer for all Generator parameters.
|
# This default implementation defines a single optimizer for all Generator parameters.
|
||||||
# Must be called after networks are initialized and wrapped.
|
# Must be called after networks are initialized and wrapped.
|
||||||
def define_optimizers(self):
|
def define_optimizers(self):
|
||||||
self.training_net = self.env['generators'][self.step_opt['training']] \
|
training = self.step_opt['training']
|
||||||
if self.step_opt['training'] in self.env['generators'].keys() \
|
if isinstance(training, list):
|
||||||
else self.env['discriminators'][self.step_opt['training']]
|
self.training_net = [self.get_network_for_name(t) for t in training]
|
||||||
optim_params = []
|
opt_configs = [self.step_opt['optimizer_params'][t] for t in training]
|
||||||
for k, v in self.training_net.named_parameters(): # can optimize for a part of the model
|
nets = self.training_net
|
||||||
if v.requires_grad:
|
else:
|
||||||
optim_params.append(v)
|
self.training_net = self.get_network_for_name(training)
|
||||||
|
# When only training one network, optimizer params can just embedded in the step params.
|
||||||
|
if 'optimizer_params' not in self.step_opt.keys():
|
||||||
|
opt_configs = [self.step_opt]
|
||||||
else:
|
else:
|
||||||
if self.env['rank'] <= 0:
|
opt_configs = [self.step_opt['optimizer_params']]
|
||||||
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
nets = [self.training_net]
|
||||||
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
|
self.optimizers = []
|
||||||
opt = torch.optim.Adam(optim_params, lr=self.step_opt['lr'],
|
for net, opt_config in zip(nets, opt_configs):
|
||||||
weight_decay=self.step_opt['weight_decay'],
|
optim_params = []
|
||||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
for k, v in net.named_parameters(): # can optimize for a part of the model
|
||||||
elif self.step_opt['optimizer'] == 'novograd':
|
if v.requires_grad:
|
||||||
opt = NovoGrad(optim_params, lr=self.step_opt['lr'], weight_decay=self.step_opt['weight_decay'],
|
optim_params.append(v)
|
||||||
betas=(self.step_opt['beta1'], self.step_opt['beta2']))
|
else:
|
||||||
self.optimizers = [opt]
|
if self.env['rank'] <= 0:
|
||||||
|
logger.warning('Params [{:s}] will not optimize.'.format(k))
|
||||||
|
|
||||||
|
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
|
||||||
|
opt = torch.optim.Adam(optim_params, lr=opt_config['lr'],
|
||||||
|
weight_decay=opt_config['weight_decay'],
|
||||||
|
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||||
|
elif self.step_opt['optimizer'] == 'novograd':
|
||||||
|
opt = NovoGrad(optim_params, lr=opt_config['lr'], weight_decay=opt_config['weight_decay'],
|
||||||
|
betas=(opt_config['beta1'], opt_config['beta2']))
|
||||||
|
self.optimizers.append(opt)
|
||||||
|
|
||||||
# Returns all optimizers used in this step.
|
# Returns all optimizers used in this step.
|
||||||
def get_optimizers(self):
|
def get_optimizers(self):
|
||||||
|
@ -72,7 +89,10 @@ class ConfigurableStep(Module):
|
||||||
|
|
||||||
# Returns the names of the networks this step will train. Other networks will be frozen.
|
# Returns the names of the networks this step will train. Other networks will be frozen.
|
||||||
def get_networks_trained(self):
|
def get_networks_trained(self):
|
||||||
return [self.step_opt['training']]
|
if isinstance(self.step_opt['training'], list):
|
||||||
|
return self.step_opt['training']
|
||||||
|
else:
|
||||||
|
return [self.step_opt['training']]
|
||||||
|
|
||||||
# Performs all forward and backward passes for this step given an input state. All input states are lists of
|
# Performs all forward and backward passes for this step given an input state. All input states are lists of
|
||||||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||||
|
|
Loading…
Reference in New Issue
Block a user