forked from mrq/DL-Art-School

Fix DDP errors for discriminator

- Don't define training_net in define_optimizers - this drops the shell and leads to problems downstream.
- Get rid of support for multiple training nets per opt. This was half-baked and needs a better solution if it is needed downstream.

This commit is contained in:
parent 8a83b1c716
commit ea56eb61f0
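In outline: ConfigurableStep is itself an nn.Module, so caching the resolved network on self.training_net freezes whichever object the lookup returned at that moment. If the registry entry is (or later becomes) a DistributedDataParallel wrapper, the cached handle can end up pointing at the bare module, "dropping the shell" so downstream calls bypass DDP. A minimal sketch of the two access patterns, assuming a hypothetical name-to-network registry (StepSketch and its members are illustrative stand-ins, not the trainer's real API):

import torch.nn as nn

class StepSketch(nn.Module):
    """Illustrative stand-in for ConfigurableStep (hypothetical names)."""

    def __init__(self, networks):
        super().__init__()
        # name -> network; entries may be re-wrapped in DDP after init.
        self.networks = networks

    def get_network_for_name(self, name):
        return self.networks[name]

    def cached_forward(self, name, batch):
        # The pattern this commit removes: the resolved object is cached,
        # so a DDP shell attached (or returned) later is never seen again.
        self.training_net = self.get_network_for_name(name)
        return self.training_net(batch)

    def resolved_forward(self, name, batch):
        # The pattern this commit adopts: resolve on every use, so the call
        # always goes through the current (possibly DDP-wrapped) object.
        return self.get_network_for_name(name)(batch)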
@@ -54,19 +54,14 @@ class ConfigurableStep(Module):
     # Must be called after networks are initialized and wrapped.
     def define_optimizers(self):
         training = self.step_opt['training']
-        if isinstance(training, list):
-            self.training_net = [self.get_network_for_name(t) for t in training]
-            opt_configs = [self.step_opt['optimizer_params'][t] for t in training]
-            nets = self.training_net
-        else:
-            self.training_net = self.get_network_for_name(training)
-            # When only training one network, optimizer params can just embedded in the step params.
-            if 'optimizer_params' not in self.step_opt.keys():
-                opt_configs = [self.step_opt]
-            else:
-                opt_configs = [self.step_opt['optimizer_params']]
-            nets = [self.training_net]
-            training = [training]
+        training_net = self.get_network_for_name(training)
+        # When only training one network, optimizer params can just embedded in the step params.
+        if 'optimizer_params' not in self.step_opt.keys():
+            opt_configs = [self.step_opt]
+        else:
+            opt_configs = [self.step_opt['optimizer_params']]
+        nets = [training_net]
+        training = [training]
         self.optimizers = []
         for net_name, net, opt_config in zip(training, nets, opt_configs):
             optim_params = []
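The replacement body still normalizes everything to single-element lists, so the existing zip() loop below it is untouched. A sketch of the two step-option shapes the simplified define_optimizers() accepts (keys mirror the diff; the values are made up):

# Optimizer params embedded directly in the step options:
step_opt = {'training': 'discriminator', 'lr': 1e-4}
# -> opt_configs = [step_opt]

# Optimizer params nested under their own key:
step_opt = {'training': 'discriminator',
            'optimizer_params': {'lr': 1e-4, 'weight_decay': 0}}
# -> opt_configs = [step_opt['optimizer_params']]

# Either way the loop sees exactly one (name, net, config) triple;
# list-valued 'training' entries are no longer supported.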
@@ -156,7 +151,7 @@ class ConfigurableStep(Module):
                     'before' in loss.opt.keys() and self.env['step'] > loss.opt['before'] or \
                     'every' in loss.opt.keys() and self.env['step'] % loss.opt['every'] != 0:
                 continue
-            l = loss(self.training_net, local_state)
+            l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
             total_loss += l * self.weights[loss_name]
             # Record metrics.
             if isinstance(l, torch.Tensor):
@@ -181,7 +176,7 @@ class ConfigurableStep(Module):
             # backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good
             # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
             # implement it at the loss level.
-            self.training_net.zero_grad()
+            self.get_network_for_name(self.step_opt['training']).zero_grad()
             self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))

         self.grads_generated = True
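For context, the comment in this hunk explains why a "skipped" step still touches gradients rather than skipping backward(): every DDP rank must reach the backward() collective or the ranks desynchronize. A rough sketch of that constraint, not the trainer's exact control flow (net, total_loss, min_loss, and skipped_steps are illustrative names):

# Sketch of the skipped-step pattern the comment above describes
# (illustrative names; ordering here is one coherent way to honor it):
total_loss.backward()      # must run on every rank - DDP syncs grads here
if total_loss.item() < min_loss:
    net.zero_grad()        # discard the gradients instead of stepping...
    skipped_steps += 1     # ...so the update is effectively a no-op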