diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index 6ecb4085..14d995e1 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -125,6 +125,12 @@ class GeneratorGanLoss(ConfigurableLoss):
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
         self.noise = None if 'noise' not in opt.keys() else opt['noise']
         self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
+        # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
+        # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+        self.rb_ptr = 0
+        self.losses_computed = 0

     def forward(self, _, state):
         netD = self.env['discriminators'][self.opt['discriminator']]
@@ -144,16 +150,23 @@ class GeneratorGanLoss(ConfigurableLoss):
             fake = nfake
         if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
             pred_g_fake = netD(*fake)
-            return self.criterion(pred_g_fake, True)
+            loss = self.criterion(pred_g_fake, True)
         elif self.opt['gan_type'] == 'ragan':
             pred_d_real = netD(*real)
             if self.detach_real:
                 pred_d_real = pred_d_real.detach()
             pred_g_fake = netD(*fake)
-            return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+            loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
                     self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
         else:
             raise NotImplementedError
+        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+            return 0
+        self.losses_computed += 1
+        self.metrics.append(("loss_counter", self.losses_computed))
+        return loss


 class DiscriminatorGanLoss(ConfigurableLoss):
@@ -162,6 +175,12 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         self.opt = opt
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
         self.noise = None if 'noise' not in opt.keys() else opt['noise']
+        # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
+        # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+        self.rb_ptr = 0
+        self.losses_computed = 0

     def forward(self, net, state):
         self.metrics = []
@@ -190,14 +209,21 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             l_real = self.criterion(d_real, True)
             l_fake = self.criterion(d_fake, False)
             l_total = l_real + l_fake
-            return l_total
+            loss = l_total
         elif self.opt['gan_type'] == 'ragan':
             d_fake_diff = d_fake - torch.mean(d_real)
             self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-            return (self.criterion(d_real - torch.mean(d_fake), True) +
+            loss = (self.criterion(d_real - torch.mean(d_fake), True) +
                     self.criterion(d_fake_diff, False))
         else:
             raise NotImplementedError
+        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+            return 0
+        self.losses_computed += 1
+        self.metrics.append(("loss_counter", self.losses_computed))
+        return loss


 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py
index 6c2cb6c2..a8762360 100644
--- a/codes/utils/loss_accumulator.py
+++ b/codes/utils/loss_accumulator.py
@@ -10,7 +10,11 @@ class LossAccumulator:
         if name not in self.buffers.keys():
             self.buffers[name] = (0, torch.zeros(self.buffer_sz))
         i, buf = self.buffers[name]
-        buf[i] = tensor.detach().cpu()
+        # Can take tensors or just plain python numbers.
+        if isinstance(tensor, torch.Tensor):
+            buf[i] = tensor.detach().cpu()
+        else:
+            buf[i] = tensor
         self.buffers[name] = ((i+1) % self.buffer_sz, buf)

     def as_dict(self):
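Note on the mechanism: the gating logic added to both loss classes is identical, so its behavior can be illustrated in isolation. Below is a minimal, hypothetical sketch of that mechanism; the MinLossGate name and the standalone structure are illustrative only, the real code lives inline in GeneratorGanLoss and DiscriminatorGanLoss above.

import torch

class MinLossGate:
    """Skips backprop for a loss whenever its recent average falls below a floor.

    Illustrative mirror of the rotating-buffer logic added in this patch.
    min_loss defaults to 0, which effectively disables the gate for the
    non-negative BCE-style GAN losses used here.
    """
    def __init__(self, min_loss=0.0, buffer_sz=10):
        self.min_loss = min_loss
        # Rolling window of the last `buffer_sz` loss values, kept out of autograd.
        self.loss_rotating_buffer = torch.zeros(buffer_sz, requires_grad=False)
        self.rb_ptr = 0
        self.losses_computed = 0

    def __call__(self, loss):
        # Record this step's scalar loss value in the rotating buffer.
        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
        # If the rolling mean has dropped below the floor, return 0 instead of
        # the loss tensor: nothing is backpropagated and this network skips a
        # step while its counterpart "catches up".
        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
            return 0
        self.losses_computed += 1
        return loss

if __name__ == '__main__':
    gate = MinLossGate(min_loss=0.2)
    # High losses pass through; once persistently low losses fill the window,
    # the rolling mean sinks below the floor and the gate returns 0.
    for step, v in enumerate([2.0] * 10 + [0.01] * 15):
        out = gate(torch.tensor(v, requires_grad=True))
        if step % 5 == 0:
            print(step, out)

Two properties of the patch are worth noting. First, the buffer starts zero-filled, so the rolling mean is biased low for the first ten steps and a nonzero min_loss will suppress them. Second, the gated case returns a plain python int 0 rather than a tensor, which is presumably why the companion LossAccumulator change teaches it to accept plain python numbers as well as tensors.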