Merge remote-tracking branch 'origin/gan_lab' into gan_lab

# Conflicts:
#	codes/models/archs/SwitchedResidualGenerator_arch.py
This commit is contained in:
James Betker 2020-10-13 10:12:26 -06:00
commit bdf4c38899
3 changed files with 7 additions and 5 deletions

View File

@ -64,6 +64,9 @@ class ConfigurableLoss(nn.Module):
def extra_metrics(self): def extra_metrics(self):
return self.metrics return self.metrics
def clear_metrics(self):
self.metrics = []
def get_basic_criterion_for_name(name, device): def get_basic_criterion_for_name(name, device):
if name == 'l1': if name == 'l1':
@ -192,7 +195,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
self.losses_computed = 0 self.losses_computed = 0
def forward(self, net, state): def forward(self, net, state):
self.metrics = []
real = extract_params_from_state(self.opt['real'], state) real = extract_params_from_state(self.opt['real'], state)
real = [r.detach() for r in real] real = [r.detach() for r in real]
fake = extract_params_from_state(self.opt['fake'], state) fake = extract_params_from_state(self.opt['fake'], state)
@ -258,7 +260,6 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
(functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))]) (functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))])
def forward(self, net, state): def forward(self, net, state):
self.metrics = []
net = self.env['generators'][self.generator] # Get the network from an explicit parameter. net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
# The <net> parameter is not reliable for generator losses since often they are combined with many networks. # The <net> parameter is not reliable for generator losses since often they are combined with many networks.
fake = extract_params_from_state(self.opt['fake'], state) fake = extract_params_from_state(self.opt['fake'], state)
@ -305,7 +306,6 @@ class TranslationInvarianceLoss(ConfigurableLoss):
assert(self.patch_size > self.overlap) assert(self.patch_size > self.overlap)
def forward(self, net, state): def forward(self, net, state):
self.metrics = []
net = self.env['generators'][self.generator] # Get the network from an explicit parameter. net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
# The <net> parameter is not reliable for generator losses since often they are combined with many networks. # The <net> parameter is not reliable for generator losses since often they are combined with many networks.
@ -356,7 +356,6 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
assert(self.recursive_depth > 0) assert(self.recursive_depth > 0)
def forward(self, net, state): def forward(self, net, state):
self.metrics = []
net = self.env['generators'][self.generator] # Get the network from an explicit parameter. net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
# The <net> parameter is not reliable for generator losses since they can be combined with many networks. # The <net> parameter is not reliable for generator losses since they can be combined with many networks.

View File

@ -156,6 +156,7 @@ class ConfigurableStep(Module):
self.loss_accumulator.add_loss(loss_name, l) self.loss_accumulator.add_loss(loss_name, l)
for n, v in loss.extra_metrics(): for n, v in loss.extra_metrics():
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v) self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
loss.clear_metrics()
# In some cases, the loss could not be set (e.g. all losses have 'after' # In some cases, the loss could not be set (e.g. all losses have 'after'
if isinstance(total_loss, torch.Tensor): if isinstance(total_loss, torch.Tensor):

View File

@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env) super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
self.flow = opt['flow_network'] self.flow = opt['flow_network']
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0 self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
self.recurrent_index = opt['recurrent_index'] self.recurrent_index = opt['recurrent_index']
self.scale = opt['scale'] self.scale = opt['scale']
self.resample = Resample2d() self.resample = Resample2d()
@ -271,3 +271,5 @@ class PingPongLoss(ConfigurableLoss):
for i in range(cnt): for i in range(cnt):
img = imglist[:, i] img = imglist[:, i]
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))