diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index af9d8805..73fe38b3 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -137,10 +137,17 @@ class DiscriminatorGanLoss(ConfigurableLoss): fake = extract_params_from_state(self.opt['fake'], state) fake = [f.detach() for f in fake] if self.noise: - # An assumption is made that the first input to the discriminator is what we want to add noise to. If not, - # use a explicit formulation of adding noise (using an injector) - real[0] += torch.randn_like(real[0]) - fake[0] += torch.randn_like(fake[0]) + nreal = [] + nfake = [] + for i, t in enumerate(real): + if isinstance(t, torch.Tensor): + nreal.append(t + torch.randn_like(t) * self.noise) + nfake.append(fake[i] + torch.randn_like(t) * self.noise) + else: + nreal.append(t) + nfake.append(fake[i]) + real = nreal + fake = nfake d_real = net(*real) d_fake = net(*fake)