Clear metrics properly

Holy cow, what a PITA bug.
James Betker 2020-10-13 10:07:49 -06:00
parent 4d52374e60
commit 8014f050ac
3 changed files with 7 additions and 5 deletions


@@ -64,6 +64,9 @@ class ConfigurableLoss(nn.Module):
     def extra_metrics(self):
         return self.metrics
 
+    def clear_metrics(self):
+        self.metrics = []
+
 
 def get_basic_criterion_for_name(name, device):
     if name == 'l1':
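The new clear_metrics() hook moves the reset of the per-loss metrics list out of each subclass's forward() and into the caller. Below is a minimal sketch of the lifecycle this implies; the base class is simplified, and ExampleLoss and its metric name are hypothetical. Only extra_metrics() and clear_metrics() come from the diff.

import torch.nn as nn


class ConfigurableLoss(nn.Module):
    # Simplified stand-in for the base class in this file.
    def __init__(self):
        super().__init__()
        self.metrics = []

    def extra_metrics(self):
        return self.metrics

    def clear_metrics(self):
        self.metrics = []


class ExampleLoss(ConfigurableLoss):
    # Hypothetical subclass: it only appends metrics and no longer resets
    # self.metrics at the top of forward(), since the step now clears them.
    def forward(self, net, state):
        loss = state['fake'].mean()
        self.metrics.append(('fake_mean', loss.detach()))
        return loss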
@@ -192,7 +195,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         self.losses_computed = 0
 
     def forward(self, net, state):
-        self.metrics = []
         real = extract_params_from_state(self.opt['real'], state)
         real = [r.detach() for r in real]
         fake = extract_params_from_state(self.opt['fake'], state)
@@ -258,7 +260,6 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
                              (functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))])
 
     def forward(self, net, state):
-        self.metrics = []
         net = self.env['generators'][self.generator]  # Get the network from an explicit parameter.
         # The <net> parameter is not reliable for generator losses since often they are combined with many networks.
         fake = extract_params_from_state(self.opt['fake'], state)
@@ -305,7 +306,6 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         assert(self.patch_size > self.overlap)
 
     def forward(self, net, state):
-        self.metrics = []
         net = self.env['generators'][self.generator]  # Get the network from an explicit parameter.
         # The <net> parameter is not reliable for generator losses since often they are combined with many networks.
@@ -356,7 +356,6 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
         assert(self.recursive_depth > 0)
 
     def forward(self, net, state):
-        self.metrics = []
         net = self.env['generators'][self.generator]  # Get the network from an explicit parameter.
         # The <net> parameter is not reliable for generator losses since they can be combined with many networks.


@@ -156,6 +156,7 @@ class ConfigurableStep(Module):
                 self.loss_accumulator.add_loss(loss_name, l)
                 for n, v in loss.extra_metrics():
                     self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
+                loss.clear_metrics()
 
         # In some cases, the loss could not be set (e.g. all losses have 'after'
         if isinstance(total_loss, torch.Tensor):
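With the hunk above, the step drains each loss's extra metrics into the accumulator and then calls clear_metrics(), so entries recorded on one iteration can no longer carry over into the next one when that loss's forward() is skipped. A minimal sketch of the accumulate-then-clear pattern, assuming the simplified ConfigurableLoss sketched earlier; SimpleAccumulator is a stand-in exposing only the add_loss() call visible in the diff.

class SimpleAccumulator:
    # Stand-in for the project's loss accumulator; only add_loss() is assumed.
    def __init__(self):
        self.buffer = {}

    def add_loss(self, name, value):
        self.buffer.setdefault(name, []).append(value)


def gather_loss_metrics(accumulator, loss_name, loss):
    # Copy the metrics out, then clear them on the loss object itself.
    for n, v in loss.extra_metrics():
        accumulator.add_loss("%s_%s" % (loss_name, n), v)
    loss.clear_metrics()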


@@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
         self.flow = opt['flow_network']
         self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
-        self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0
+        self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
         self.recurrent_index = opt['recurrent_index']
         self.scale = opt['scale']
         self.resample = Resample2d()
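The one-line change above fixes a key-name mismatch: the value was read from opt['output_hq_index'], but the existence check looked for 'output_index', so the configured index silently fell back to 0 unless an unrelated 'output_index' key happened to be present. If opt behaves like a plain dict, dict.get() expresses the same lookup with each key named exactly once, which avoids this class of typo; a sketch with a hypothetical helper:

def read_injector_indices(opt):
    # dict.get() names each key once, avoiding the 'output_index' /
    # 'output_hq_index' mismatch corrected in this hunk.
    input_lq_index = opt.get('input_lq_index', 0)
    output_hq_index = opt.get('output_hq_index', 0)
    return input_lq_index, output_hq_index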
@@ -271,3 +271,5 @@ class PingPongLoss(ConfigurableLoss):
         for i in range(cnt):
             img = imglist[:, i]
             torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))