Add a more sophisticated mechanism for balancing GAN losses

James Betker 2020-10-02 22:53:42 -06:00
parent 39865ca3df
commit dd9d7b27ac
2 changed files with 35 additions and 5 deletions


@@ -125,6 +125,12 @@ class GeneratorGanLoss(ConfigurableLoss):
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
         self.noise = None if 'noise' not in opt.keys() else opt['noise']
         self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
+        # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
+        # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+        self.rb_ptr = 0
+        self.losses_computed = 0

     def forward(self, _, state):
         netD = self.env['discriminators'][self.opt['discriminator']]
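For context, the new 'min_loss' knob is read from the loss options the same way as the existing keys. A hypothetical options dict follows; only the keys actually read in __init__ above are real, and every value is illustrative:

# Hypothetical loss options; keys mirror what __init__ reads above,
# values are made up for illustration.
opt = {
    'gan_type': 'ragan',
    'discriminator': 'disc',  # key into env['discriminators']
    'detach_real': True,
    'min_loss': 0.05,         # new: gate G steps while its running mean loss is below this
}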
@@ -144,16 +150,23 @@ class GeneratorGanLoss(ConfigurableLoss):
             fake = nfake
         if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
             pred_g_fake = netD(*fake)
-            return self.criterion(pred_g_fake, True)
+            loss = self.criterion(pred_g_fake, True)
         elif self.opt['gan_type'] == 'ragan':
             pred_d_real = netD(*real)
             if self.detach_real:
                 pred_d_real = pred_d_real.detach()
             pred_g_fake = netD(*fake)
-            return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+            loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
         else:
             raise NotImplementedError
+        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+            return 0
+        self.losses_computed += 1
+        self.metrics.append(("loss_counter", self.losses_computed))
+        return loss


 class DiscriminatorGanLoss(ConfigurableLoss):
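Distilled from the forward() change above, here is the gating logic in isolation; a minimal standalone sketch assuming only torch, where the class name LossGate is invented for illustration and is not part of the repo:

import torch

# Track the last 10 scalar loss values in a rotating buffer and suppress the
# loss (so no gradient flows) while their mean sits below min_loss.
# 'LossGate' is an illustrative name, not part of this codebase.
class LossGate:
    def __init__(self, min_loss=0.0, buffer_len=10):
        self.min_loss = min_loss
        self.loss_rotating_buffer = torch.zeros(buffer_len, requires_grad=False)
        self.rb_ptr = 0
        self.losses_computed = 0

    def __call__(self, loss):
        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
            return 0  # skip this step so the counterpart network can catch up
        self.losses_computed += 1
        return loss

One consequence of this design: the buffer initializes to zeros, so whenever min_loss > 0 the running mean starts below the threshold and the first several steps are skipped until the buffer fills with real values.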
@@ -162,6 +175,12 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         self.opt = opt
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
         self.noise = None if 'noise' not in opt.keys() else opt['noise']
+        # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
+        # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+        self.rb_ptr = 0
+        self.losses_computed = 0

     def forward(self, net, state):
         self.metrics = []
@@ -190,14 +209,21 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             l_real = self.criterion(d_real, True)
             l_fake = self.criterion(d_fake, False)
             l_total = l_real + l_fake
-            return l_total
+            loss = l_total
         elif self.opt['gan_type'] == 'ragan':
             d_fake_diff = d_fake - torch.mean(d_real)
             self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-            return (self.criterion(d_real - torch.mean(d_fake), True) +
+            loss = (self.criterion(d_real - torch.mean(d_fake), True) +
                     self.criterion(d_fake_diff, False))
         else:
             raise NotImplementedError
+        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+            return 0
+        self.losses_computed += 1
+        self.metrics.append(("loss_counter", self.losses_computed))
+        return loss


 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
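For reference, the 'ragan' branches in both classes implement the relativistic average GAN objective. A sketch with BCE-with-logits standing in for GANLoss follows; that substitution is an assumption, since GANLoss may wrap a different criterion depending on gan_type:

import torch
import torch.nn.functional as F

def ragan_d_loss(d_real, d_fake):
    # D pushes real scores above the mean fake score and fake scores below
    # the mean real score; mirrors the discriminator branch above.
    return (F.binary_cross_entropy_with_logits(d_real - d_fake.mean(), torch.ones_like(d_real)) +
            F.binary_cross_entropy_with_logits(d_fake - d_real.mean(), torch.zeros_like(d_fake)))

def ragan_g_loss(d_real, d_fake):
    # G pursues the opposite ordering; the / 2 matches the generator branch above.
    return (F.binary_cross_entropy_with_logits(d_real - d_fake.mean(), torch.zeros_like(d_real)) +
            F.binary_cross_entropy_with_logits(d_fake - d_real.mean(), torch.ones_like(d_fake))) / 2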


@@ -10,7 +10,11 @@ class LossAccumulator:
         if name not in self.buffers.keys():
             self.buffers[name] = (0, torch.zeros(self.buffer_sz))
         i, buf = self.buffers[name]
-        buf[i] = tensor.detach().cpu()
+        # Can take tensors or just plain python numbers.
+        if isinstance(tensor, torch.Tensor):
+            buf[i] = tensor.detach().cpu()
+        else:
+            buf[i] = tensor
         self.buffers[name] = ((i+1) % self.buffer_sz, buf)

     def as_dict(self):
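A usage sketch of the relaxed add(): the integer loss_counter metric emitted by the losses above can now flow through the same accumulator as tensor losses. The zero-argument constructor is an assumption; the actual signature may require a buffer size:

import torch

acc = LossAccumulator()                 # assumes a default buffer_sz exists
acc.add('loss_counter', 3)              # plain python int, stored directly
acc.add('l_total', torch.tensor(0.37))  # tensor, detached and moved to cpu first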