diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 72a3dfee..798264c0 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -140,9 +140,10 @@ class GeneratorGanLoss(ConfigurableLoss): # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance # generators and discriminators by essentially having them skip steps while their counterparts "catch up". self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0 - self.loss_rotating_buffer = torch.zeros(10, requires_grad=False) - self.rb_ptr = 0 - self.losses_computed = 0 + if self.min_loss != 0: + self.loss_rotating_buffer = torch.zeros(10, requires_grad=False) + self.rb_ptr = 0 + self.losses_computed = 0 def forward(self, _, state): netD = self.env['discriminators'][self.opt['discriminator']] @@ -172,12 +173,13 @@ class GeneratorGanLoss(ConfigurableLoss): self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2 else: raise NotImplementedError - self.loss_rotating_buffer[self.rb_ptr] = loss.item() - self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0] - if torch.mean(self.loss_rotating_buffer) < self.min_loss: - return 0 - self.losses_computed += 1 - self.metrics.append(("loss_counter", self.losses_computed)) + if self.min_loss != 0: + self.loss_rotating_buffer[self.rb_ptr] = loss.item() + self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0] + if torch.mean(self.loss_rotating_buffer) < self.min_loss: + return 0 + self.losses_computed += 1 + self.metrics.append(("loss_counter", self.losses_computed)) return loss @@ -190,9 +192,10 @@ class DiscriminatorGanLoss(ConfigurableLoss): # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance # generators and discriminators by essentially having them skip steps while their counterparts "catch up". self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0 - self.loss_rotating_buffer = torch.zeros(10, requires_grad=False) - self.rb_ptr = 0 - self.losses_computed = 0 + if self.min_loss != 0: + self.loss_rotating_buffer = torch.zeros(10, requires_grad=False) + self.rb_ptr = 0 + self.losses_computed = 0 def forward(self, net, state): real = extract_params_from_state(self.opt['real'], state) @@ -228,12 +231,13 @@ class DiscriminatorGanLoss(ConfigurableLoss): self.criterion(d_fake_diff, False)) else: raise NotImplementedError - self.loss_rotating_buffer[self.rb_ptr] = loss.item() - self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0] - if torch.mean(self.loss_rotating_buffer) < self.min_loss: - return 0 - self.losses_computed += 1 - self.metrics.append(("loss_counter", self.losses_computed)) + if self.min_loss != 0: + self.loss_rotating_buffer[self.rb_ptr] = loss.item() + self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0] + if torch.mean(self.loss_rotating_buffer) < self.min_loss: + return 0 + self.losses_computed += 1 + self.metrics.append(("loss_counter", self.losses_computed)) return loss @@ -397,6 +401,12 @@ class RecurrentLoss(ConfigurableLoss): total_loss += self.loss(net, st) return total_loss + def extra_metrics(self): + return self.loss.extra_metrics() + + def clear_metrics(self): + self.loss.clear_metrics() + # Loss that pulls a tensor from dim 1 of the input and feeds it into a "sub" loss. class ForElementLoss(ConfigurableLoss): @@ -414,3 +424,9 @@ class ForElementLoss(ConfigurableLoss): st['_real'] = state[self.opt['real']][:, self.index] st['_fake'] = state[self.opt['fake']][:, self.index] return self.loss(net, st) + + def extra_metrics(self): + return self.loss.extra_metrics() + + def clear_metrics(self): + self.loss.clear_metrics() diff --git a/codes/models/steps/progressive_zoom.py b/codes/models/steps/progressive_zoom.py index a648e235..f7661af8 100644 --- a/codes/models/steps/progressive_zoom.py +++ b/codes/models/steps/progressive_zoom.py @@ -64,6 +64,7 @@ class ProgressiveGeneratorInjector(Injector): inputs = extract_params_from_state(self.input, state) lq_inputs = inputs[self.input_lq_index] hq_inputs = state[self.hq_key] + output = self.output if not isinstance(inputs, list): inputs = [inputs] if not isinstance(self.output, list):