Add more sophisticated mechanism for balancing GAN losses
parent 39865ca3df
commit dd9d7b27ac
@@ -125,6 +125,12 @@ class GeneratorGanLoss(ConfigurableLoss):
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
         self.noise = None if 'noise' not in opt.keys() else opt['noise']
         self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
+        # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
+        # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+        self.rb_ptr = 0
+        self.losses_computed = 0

     def forward(self, _, state):
         netD = self.env['discriminators'][self.opt['discriminator']]
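The skip mechanism is configured entirely through the loss's `opt` dict. A hypothetical options block for reference; only the keys read in the hunk above are real, the values and surrounding structure are illustrative:

    # Hypothetical opt dict for GeneratorGanLoss. Key names come from the
    # diff above; the values shown are made up for illustration.
    opt = {
        'gan_type': 'ragan',      # 'gan', 'pixgan', 'pixgan_fea', or 'ragan'
        'discriminator': 'disc',  # key into env['discriminators']
        'min_loss': .05,          # skip steps while the running mean loss is below this
        'detach_real': True,      # stop gradients through the real-path D predictions
    }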
@@ -144,16 +150,23 @@ class GeneratorGanLoss(ConfigurableLoss):
             fake = nfake
         if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
             pred_g_fake = netD(*fake)
-            return self.criterion(pred_g_fake, True)
+            loss = self.criterion(pred_g_fake, True)
         elif self.opt['gan_type'] == 'ragan':
             pred_d_real = netD(*real)
             if self.detach_real:
                 pred_d_real = pred_d_real.detach()
             pred_g_fake = netD(*fake)
-            return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+            loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
         else:
             raise NotImplementedError
+        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+            return 0
+        self.losses_computed += 1
+        self.metrics.append(("loss_counter", self.losses_computed))
+        return loss


 class DiscriminatorGanLoss(ConfigurableLoss):
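The added gating keeps a ring buffer of the last 10 scalar loss values; when their mean drops below `min_loss`, the loss returns `0` and no gradient flows for that step. A minimal standalone sketch of the same idea, using a hypothetical `LossGate` helper that is not part of this commit:

    import torch

    class LossGate:
        """Tracks a running mean of the last `window` loss values and signals
        when backprop should be skipped because the loss has dropped too low."""
        def __init__(self, min_loss, window=10):
            self.min_loss = min_loss
            self.buffer = torch.zeros(window, requires_grad=False)
            self.ptr = 0

        def should_skip(self, loss):
            # Record the scalar value, then advance the ring-buffer pointer.
            self.buffer[self.ptr] = loss.item()
            self.ptr = (self.ptr + 1) % self.buffer.shape[0]
            # Skip while the smoothed loss is below the floor, letting the
            # counterpart network "catch up".
            return bool(torch.mean(self.buffer) < self.min_loss)

    # Usage inside a loss: if gate.should_skip(loss): return 0
    gate = LossGate(min_loss=.05)

Note that because the buffer starts zeroed, a positive `min_loss` also gates the first few steps until the buffer fills with real values.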
@@ -162,6 +175,12 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         self.opt = opt
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
         self.noise = None if 'noise' not in opt.keys() else opt['noise']
+        # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
+        # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+        self.rb_ptr = 0
+        self.losses_computed = 0

     def forward(self, net, state):
         self.metrics = []
@@ -190,14 +209,21 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             l_real = self.criterion(d_real, True)
             l_fake = self.criterion(d_fake, False)
             l_total = l_real + l_fake
-            return l_total
+            loss = l_total
         elif self.opt['gan_type'] == 'ragan':
             d_fake_diff = d_fake - torch.mean(d_real)
             self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-            return (self.criterion(d_real - torch.mean(d_fake), True) +
+            loss = (self.criterion(d_real - torch.mean(d_fake), True) +
                     self.criterion(d_fake_diff, False))
         else:
             raise NotImplementedError
+        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+            return 0
+        self.losses_computed += 1
+        self.metrics.append(("loss_counter", self.losses_computed))
+        return loss


 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
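Since both losses now return the plain integer `0` when gated, a caller that backprops the result effectively skips the optimizer step for that network. A hypothetical consumption pattern, not from this repo; the direct `.forward()` call and optimizer names are assumptions:

    # Hypothetical training step illustrating the balancing effect.
    d_loss = d_gan_loss.forward(netD, state)
    if isinstance(d_loss, torch.Tensor):  # a real loss was returned
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()
    # else: the discriminator skips this step while the generator catches up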
@@ -10,7 +10,11 @@ class LossAccumulator:
         if name not in self.buffers.keys():
             self.buffers[name] = (0, torch.zeros(self.buffer_sz))
         i, buf = self.buffers[name]
-        buf[i] = tensor.detach().cpu()
+        # Can take tensors or just plain python numbers.
+        if isinstance(tensor, torch.Tensor):
+            buf[i] = tensor.detach().cpu()
+        else:
+            buf[i] = tensor
         self.buffers[name] = ((i+1) % self.buffer_sz, buf)

     def as_dict(self):
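This `LossAccumulator` change is what lets the integer `0` and the integer `loss_counter` metric flow through the same logging path as tensor losses. A quick usage sketch; the enclosing method name is not visible in the hunk, so `add_loss` and the constructor signature here are assumptions:

    # Hypothetical usage of the widened accumulator.
    acc = LossAccumulator()                           # constructor args assumed
    acc.add_loss("loss_counter", 42)                  # plain python number
    acc.add_loss("d_fake_diff", torch.tensor(0.13))   # tensor, detached to cpu
    print(acc.as_dict())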