diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 754dacc2..af9d8805 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -129,12 +129,18 @@ class DiscriminatorGanLoss(ConfigurableLoss): super(DiscriminatorGanLoss, self).__init__(opt, env) self.opt = opt self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) + self.noise = None if 'noise' not in opt.keys() else opt['noise'] def forward(self, net, state): self.metrics = [] real = extract_params_from_state(self.opt['real'], state) fake = extract_params_from_state(self.opt['fake'], state) fake = [f.detach() for f in fake] + if self.noise: + # An assumption is made that the first input to the discriminator is what we want to add noise to. If not, + # use a explicit formulation of adding noise (using an injector) + real[0] += torch.randn_like(real[0]) + fake[0] += torch.randn_like(fake[0]) d_real = net(*real) d_fake = net(*fake)