Clear metrics properly
Holy cow, what a PITA bug.
This commit is contained in:
parent
4d52374e60
commit
8014f050ac
|
@ -64,6 +64,9 @@ class ConfigurableLoss(nn.Module):
|
|||
def extra_metrics(self):
|
||||
return self.metrics
|
||||
|
||||
def clear_metrics(self):
|
||||
self.metrics = []
|
||||
|
||||
|
||||
def get_basic_criterion_for_name(name, device):
|
||||
if name == 'l1':
|
||||
|
@ -192,7 +195,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
self.losses_computed = 0
|
||||
|
||||
def forward(self, net, state):
|
||||
self.metrics = []
|
||||
real = extract_params_from_state(self.opt['real'], state)
|
||||
real = [r.detach() for r in real]
|
||||
fake = extract_params_from_state(self.opt['fake'], state)
|
||||
|
@ -258,7 +260,6 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
|
|||
(functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))])
|
||||
|
||||
def forward(self, net, state):
|
||||
self.metrics = []
|
||||
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
|
||||
# The <net> parameter is not reliable for generator losses since often they are combined with many networks.
|
||||
fake = extract_params_from_state(self.opt['fake'], state)
|
||||
|
@ -305,7 +306,6 @@ class TranslationInvarianceLoss(ConfigurableLoss):
|
|||
assert(self.patch_size > self.overlap)
|
||||
|
||||
def forward(self, net, state):
|
||||
self.metrics = []
|
||||
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
|
||||
# The <net> parameter is not reliable for generator losses since often they are combined with many networks.
|
||||
|
||||
|
@ -356,7 +356,6 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
|
|||
assert(self.recursive_depth > 0)
|
||||
|
||||
def forward(self, net, state):
|
||||
self.metrics = []
|
||||
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
|
||||
# The <net> parameter is not reliable for generator losses since they can be combined with many networks.
|
||||
|
||||
|
|
|
@ -156,6 +156,7 @@ class ConfigurableStep(Module):
|
|||
self.loss_accumulator.add_loss(loss_name, l)
|
||||
for n, v in loss.extra_metrics():
|
||||
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
||||
loss.clear_metrics()
|
||||
|
||||
# In some cases, the loss could not be set (e.g. all losses have 'after'
|
||||
if isinstance(total_loss, torch.Tensor):
|
||||
|
|
|
@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
|
||||
self.flow = opt['flow_network']
|
||||
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
|
||||
self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0
|
||||
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
||||
self.recurrent_index = opt['recurrent_index']
|
||||
self.scale = opt['scale']
|
||||
self.resample = Resample2d()
|
||||
|
@ -271,3 +271,5 @@ class PingPongLoss(ConfigurableLoss):
|
|||
for i in range(cnt):
|
||||
img = imglist[:, i]
|
||||
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user