Allow discriminator noise to be injected at the loss level, cleans up configs
This commit is contained in:
parent
e9a39bfa14
commit
3138f98fbc
|
@ -129,12 +129,18 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
super(DiscriminatorGanLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||
self.noise = None if 'noise' not in opt.keys() else opt['noise']
|
||||
|
||||
def forward(self, net, state):
|
||||
self.metrics = []
|
||||
real = extract_params_from_state(self.opt['real'], state)
|
||||
fake = extract_params_from_state(self.opt['fake'], state)
|
||||
fake = [f.detach() for f in fake]
|
||||
if self.noise:
|
||||
# An assumption is made that the first input to the discriminator is what we want to add noise to. If not,
|
||||
# use a explicit formulation of adding noise (using an injector)
|
||||
real[0] += torch.randn_like(real[0])
|
||||
fake[0] += torch.randn_like(fake[0])
|
||||
d_real = net(*real)
|
||||
d_fake = net(*fake)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user