Allow discriminator noise to be injected at the loss level, cleans up configs
This commit is contained in:
parent
e9a39bfa14
commit
3138f98fbc
|
@ -129,12 +129,18 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
super(DiscriminatorGanLoss, self).__init__(opt, env)
|
super(DiscriminatorGanLoss, self).__init__(opt, env)
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||||
|
self.noise = None if 'noise' not in opt.keys() else opt['noise']
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
self.metrics = []
|
self.metrics = []
|
||||||
real = extract_params_from_state(self.opt['real'], state)
|
real = extract_params_from_state(self.opt['real'], state)
|
||||||
fake = extract_params_from_state(self.opt['fake'], state)
|
fake = extract_params_from_state(self.opt['fake'], state)
|
||||||
fake = [f.detach() for f in fake]
|
fake = [f.detach() for f in fake]
|
||||||
|
if self.noise:
|
||||||
|
# An assumption is made that the first input to the discriminator is what we want to add noise to. If not,
|
||||||
|
# use a explicit formulation of adding noise (using an injector)
|
||||||
|
real[0] += torch.randn_like(real[0])
|
||||||
|
fake[0] += torch.randn_like(fake[0])
|
||||||
d_real = net(*real)
|
d_real = net(*real)
|
||||||
d_fake = net(*fake)
|
d_fake = net(*fake)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user