From 8014f050ac2488f87e5ee025259ef64b675a1f6f Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 13 Oct 2020 10:07:49 -0600 Subject: [PATCH] Clear metrics properly Holy cow, what a PITA bug. --- codes/models/steps/losses.py | 7 +++---- codes/models/steps/steps.py | 1 + codes/models/steps/tecogan_losses.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 3e6eeb6f..0e37e2c7 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -64,6 +64,9 @@ class ConfigurableLoss(nn.Module): def extra_metrics(self): return self.metrics + def clear_metrics(self): + self.metrics = [] + def get_basic_criterion_for_name(name, device): if name == 'l1': @@ -192,7 +195,6 @@ class DiscriminatorGanLoss(ConfigurableLoss): self.losses_computed = 0 def forward(self, net, state): - self.metrics = [] real = extract_params_from_state(self.opt['real'], state) real = [r.detach() for r in real] fake = extract_params_from_state(self.opt['fake'], state) @@ -258,7 +260,6 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss): (functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))]) def forward(self, net, state): - self.metrics = [] net = self.env['generators'][self.generator] # Get the network from an explicit parameter. # The parameter is not reliable for generator losses since often they are combined with many networks. fake = extract_params_from_state(self.opt['fake'], state) @@ -305,7 +306,6 @@ class TranslationInvarianceLoss(ConfigurableLoss): assert(self.patch_size > self.overlap) def forward(self, net, state): - self.metrics = [] net = self.env['generators'][self.generator] # Get the network from an explicit parameter. # The parameter is not reliable for generator losses since often they are combined with many networks. @@ -356,7 +356,6 @@ class RecursiveInvarianceLoss(ConfigurableLoss): assert(self.recursive_depth > 0) def forward(self, net, state): - self.metrics = [] net = self.env['generators'][self.generator] # Get the network from an explicit parameter. # The parameter is not reliable for generator losses since they can be combined with many networks. diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index f7c9ca66..6cf21b09 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -156,6 +156,7 @@ class ConfigurableStep(Module): self.loss_accumulator.add_loss(loss_name, l) for n, v in loss.extra_metrics(): self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v) + loss.clear_metrics() # In some cases, the loss could not be set (e.g. all losses have 'after' if isinstance(total_loss, torch.Tensor): diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 9f320506..c95e214b 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env) self.flow = opt['flow_network'] self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 - self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0 + self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 self.recurrent_index = opt['recurrent_index'] self.scale = opt['scale'] self.resample = Resample2d() @@ -271,3 +271,5 @@ class PingPongLoss(ConfigurableLoss): for i in range(cnt): img = imglist[:, i] torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) + +