Fix DDP errors for discriminator

- Don't define training_net in define_optimizers - doing so drops the DDP wrapper shell around the network and leads to problems downstream (see the sketch below).
- Get rid of support for multiple training nets per opt. This was half-baked and needs a better solution if it is needed downstream.
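For context, here is a minimal sketch of the pattern this change moves toward: resolve the training network by name at the point of use instead of caching a reference, so the lookup always returns the current, possibly DDP-wrapped, module. The StepSketch class and its name registry below are hypothetical stand-ins, not the repo's actual code.

import torch
import torch.nn as nn

class StepSketch:
    def __init__(self, networks, training_name):
        self.networks = networks          # name -> module; entries may be re-wrapped in DDP later
        self.training_name = training_name

    def get_network_for_name(self, name):
        # Always resolve through the registry so callers see the current (possibly DDP-wrapped) module.
        return self.networks[name]

    def define_optimizers(self):
        # Fetch the network locally rather than storing it on self; a cached reference taken
        # before DDP wrapping would bypass the wrapper and its gradient synchronization.
        training_net = self.get_network_for_name(self.training_name)
        return torch.optim.Adam(training_net.parameters(), lr=1e-4)

# Usage: the registry entry can be swapped for its DDP-wrapped version at any point,
# and later lookups still go through the wrapper.
nets = {'discriminator': nn.Linear(4, 1)}
step = StepSketch(nets, 'discriminator')
optimizer = step.define_optimizers()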
James Betker 2020-12-07 12:50:57 -07:00
parent 8a83b1c716
commit ea56eb61f0


@@ -54,19 +54,14 @@ class ConfigurableStep(Module):
     # Must be called after networks are initialized and wrapped.
     def define_optimizers(self):
         training = self.step_opt['training']
-        if isinstance(training, list):
-            self.training_net = [self.get_network_for_name(t) for t in training]
-            opt_configs = [self.step_opt['optimizer_params'][t] for t in training]
-            nets = self.training_net
-        else:
-            self.training_net = self.get_network_for_name(training)
-            # When only training one network, optimizer params can just embedded in the step params.
-            if 'optimizer_params' not in self.step_opt.keys():
-                opt_configs = [self.step_opt]
-            else:
-                opt_configs = [self.step_opt['optimizer_params']]
-            nets = [self.training_net]
-            training = [training]
+        training_net = self.get_network_for_name(training)
+        # When only training one network, optimizer params can just embedded in the step params.
+        if 'optimizer_params' not in self.step_opt.keys():
+            opt_configs = [self.step_opt]
+        else:
+            opt_configs = [self.step_opt['optimizer_params']]
+        nets = [training_net]
+        training = [training]
         self.optimizers = []
         for net_name, net, opt_config in zip(training, nets, opt_configs):
             optim_params = []
@@ -156,7 +151,7 @@ class ConfigurableStep(Module):
                'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
                'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
                 continue
-            l = loss(self.training_net, local_state)
+            l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
             total_loss += l * self.weights[loss_name]
             # Record metrics.
             if isinstance(l, torch.Tensor):
@@ -181,7 +176,7 @@ class ConfigurableStep(Module):
                 # backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good
                 # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
                 # implement it at the loss level.
-                self.training_net.zero_grad()
+                self.get_network_for_name(self.step_opt['training']).zero_grad()
                 self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
             self.grads_generated = True
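
The comment in the final hunk explains why the skipped-step path still zeroes gradients rather than simply bypassing backward(): DistributedDataParallel performs its gradient all-reduce inside backward(), so every rank has to run it on every step or the ranks fall out of sync. A single-process illustration of that constraint follows; the train_step helper is hypothetical training-loop code, not the repo's.

import torch
import torch.nn as nn

def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               loss: torch.Tensor, min_loss: float = 0.0) -> bool:
    """Run one update; skip the optimizer step (but not backward) when loss < min_loss."""
    if loss.item() < min_loss:
        # Under DDP, simply returning here would desynchronize the ranks, because DDP
        # runs its gradient all-reduce inside backward(). So run backward anyway and
        # then discard the gradients so the update is a no-op.
        loss.backward()
        model.zero_grad()
        return False
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return True

# Single-process usage; under DDP `model` would be the DistributedDataParallel wrapper.
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(2, 4)).mean()
train_step(model, opt, loss, min_loss=-1.0)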