Fix bug with discriminator noise addition

It wasn't using the scale and was applying the noise to the
underlying state variable.
This commit is contained in:
James Betker 2020-09-20 12:00:27 -06:00
parent dab8ab8a8f
commit 17dd99b29b

View File

@ -137,10 +137,17 @@ class DiscriminatorGanLoss(ConfigurableLoss):
fake = extract_params_from_state(self.opt['fake'], state) fake = extract_params_from_state(self.opt['fake'], state)
fake = [f.detach() for f in fake] fake = [f.detach() for f in fake]
if self.noise: if self.noise:
# An assumption is made that the first input to the discriminator is what we want to add noise to. If not, nreal = []
# use a explicit formulation of adding noise (using an injector) nfake = []
real[0] += torch.randn_like(real[0]) for i, t in enumerate(real):
fake[0] += torch.randn_like(fake[0]) if isinstance(t, torch.Tensor):
nreal.append(t + torch.randn_like(t) * self.noise)
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
else:
nreal.append(t)
nfake.append(fake[i])
real = nreal
fake = nfake
d_real = net(*real) d_real = net(*real)
d_fake = net(*fake) d_fake = net(*fake)