From ea56eb61f08a772d53c27e54affa8b012208ac8d Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 7 Dec 2020 12:50:57 -0700
Subject: [PATCH] Fix DDP errors for discriminator

- Don't define training_net in define_optimizers - this drops the shell
  and leads to problems downstream
- Get rid of support for multiple training nets per opt. This was half
  baked and needs a better solution if needed downstream.
---
 codes/models/steps/steps.py | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 860bbbe7..8659c6f3 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -54,19 +54,14 @@ class ConfigurableStep(Module):
     # Must be called after networks are initialized and wrapped.
     def define_optimizers(self):
         training = self.step_opt['training']
-        if isinstance(training, list):
-            self.training_net = [self.get_network_for_name(t) for t in training]
-            opt_configs = [self.step_opt['optimizer_params'][t] for t in training]
-            nets = self.training_net
+        training_net = self.get_network_for_name(training)
+        # When only training one network, optimizer params can just embedded in the step params.
+        if 'optimizer_params' not in self.step_opt.keys():
+            opt_configs = [self.step_opt]
         else:
-            self.training_net = self.get_network_for_name(training)
-            # When only training one network, optimizer params can just embedded in the step params.
-            if 'optimizer_params' not in self.step_opt.keys():
-                opt_configs = [self.step_opt]
-            else:
-                opt_configs = [self.step_opt['optimizer_params']]
-            nets = [self.training_net]
-            training = [training]
+            opt_configs = [self.step_opt['optimizer_params']]
+        nets = [training_net]
+        training = [training]
         self.optimizers = []
         for net_name, net, opt_config in zip(training, nets, opt_configs):
             optim_params = []
@@ -156,7 +151,7 @@ class ConfigurableStep(Module):
                     'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
                     'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
                 continue
-            l = loss(self.training_net, local_state)
+            l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
             total_loss += l * self.weights[loss_name]
             # Record metrics.
             if isinstance(l, torch.Tensor):
@@ -181,7 +176,7 @@ class ConfigurableStep(Module):
             # backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good
             # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
             # implement it at the loss level.
-            self.training_net.zero_grad()
+            self.get_network_for_name(self.step_opt['training']).zero_grad()
             self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
             self.grads_generated = True
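
Note on the pattern the patch adopts: instead of caching a network reference in
self.training_net during define_optimizers, the step now resolves the network by
name (self.get_network_for_name(...)) at every point of use. The sketch below is a
minimal, hypothetical illustration of why that matters under DistributedDataParallel;
the class name StepSketch, the networks dict, and do_step are invented for the
example, while get_network_for_name mirrors the helper used in the patch. DDP only
schedules its gradient all-reduce when the forward pass goes through the DDP wrapper
("the shell"), so a reference captured before or beneath wrapping would silently skip
synchronization.

    # Hypothetical sketch, not code from the repository.
    import torch
    import torch.nn as nn


    class StepSketch:
        def __init__(self, networks: dict, training_name: str):
            # `networks` maps names to modules; an entry may later be replaced
            # by its DistributedDataParallel wrapper after initialization.
            self.networks = networks
            self.training_name = training_name

        def get_network_for_name(self, name: str) -> nn.Module:
            return self.networks[name]

        def do_step(self, x: torch.Tensor) -> torch.Tensor:
            # Resolving the network here, rather than caching it when the
            # optimizers are defined, means the forward pass always goes
            # through whatever wrapper currently holds the module, so DDP's
            # backward hooks can all-reduce the gradients.
            net = self.get_network_for_name(self.training_name)
            return net(x).mean()


    if __name__ == "__main__":
        nets = {"discriminator": nn.Linear(8, 1)}
        step = StepSketch(nets, "discriminator")

        # In a torch.distributed run the entry would be swapped for its DDP
        # shell after setup, e.g.:
        #   nets["discriminator"] = nn.parallel.DistributedDataParallel(nets["discriminator"])
        # A self.training_net reference taken before that swap would bypass the shell.
        loss = step.do_step(torch.randn(4, 8))
        loss.backward()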