Allow discriminator noise to be injected at the loss level, cleans up configs

This commit is contained in:
James Betker 2020-09-19 21:47:52 -06:00
parent e9a39bfa14
commit 3138f98fbc

View File

@ -129,12 +129,18 @@ class DiscriminatorGanLoss(ConfigurableLoss):
super(DiscriminatorGanLoss, self).__init__(opt, env) super(DiscriminatorGanLoss, self).__init__(opt, env)
self.opt = opt self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
self.noise = None if 'noise' not in opt.keys() else opt['noise']
def forward(self, net, state): def forward(self, net, state):
self.metrics = [] self.metrics = []
real = extract_params_from_state(self.opt['real'], state) real = extract_params_from_state(self.opt['real'], state)
fake = extract_params_from_state(self.opt['fake'], state) fake = extract_params_from_state(self.opt['fake'], state)
fake = [f.detach() for f in fake] fake = [f.detach() for f in fake]
if self.noise:
# An assumption is made that the first input to the discriminator is what we want to add noise to. If not,
# use a explicit formulation of adding noise (using an injector)
real[0] += torch.randn_like(real[0])
fake[0] += torch.randn_like(fake[0])
d_real = net(*real) d_real = net(*real)
d_fake = net(*fake) d_fake = net(*fake)