Merge remote-tracking branch 'origin/gan_lab' into gan_lab
# Conflicts:
#	codes/models/archs/SwitchedResidualGenerator_arch.py
commit bdf4c38899
@@ -64,6 +64,9 @@ class ConfigurableLoss(nn.Module):
     def extra_metrics(self):
         return self.metrics
 
+    def clear_metrics(self):
+        self.metrics = []
+
 
 def get_basic_criterion_for_name(name, device):
     if name == 'l1':
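Pulled together, the interface this hunk establishes looks roughly like the sketch below. The base class here is a cut-down stand-in and ExampleLoss is hypothetical; only the extra_metrics()/clear_metrics() pair over a self.metrics list comes from the diff.

import torch
import torch.nn as nn


class ConfigurableLoss(nn.Module):
    # Stand-in for the real base class; only the metrics plumbing from this
    # diff (extra_metrics / clear_metrics over a self.metrics list) is shown.
    def __init__(self):
        super().__init__()
        self.metrics = []

    def extra_metrics(self):
        return self.metrics

    def clear_metrics(self):
        self.metrics = []


class ExampleLoss(ConfigurableLoss):
    # Hypothetical subclass. Losses record (name, value) pairs as they run
    # and, after this change, no longer reset self.metrics inside forward();
    # the training step clears them via clear_metrics() instead.
    def forward(self, net, state):
        l = torch.tensor(0.0)
        self.metrics.append(("example_metric", l.detach()))
        return l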
@@ -192,7 +195,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         self.losses_computed = 0
 
     def forward(self, net, state):
-        self.metrics = []
         real = extract_params_from_state(self.opt['real'], state)
         real = [r.detach() for r in real]
         fake = extract_params_from_state(self.opt['fake'], state)
@@ -258,7 +260,6 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
                             (functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))])
 
     def forward(self, net, state):
-        self.metrics = []
         net = self.env['generators'][self.generator]  # Get the network from an explicit parameter.
         # The <net> parameter is not reliable for generator losses since often they are combined with many networks.
         fake = extract_params_from_state(self.opt['fake'], state)
@@ -305,7 +306,6 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         assert(self.patch_size > self.overlap)
 
     def forward(self, net, state):
-        self.metrics = []
         net = self.env['generators'][self.generator]  # Get the network from an explicit parameter.
         # The <net> parameter is not reliable for generator losses since often they are combined with many networks.
 
@@ -356,7 +356,6 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
         assert(self.recursive_depth > 0)
 
     def forward(self, net, state):
-        self.metrics = []
         net = self.env['generators'][self.generator]  # Get the network from an explicit parameter.
         # The <net> parameter is not reliable for generator losses since they can be combined with many networks.
 
@@ -156,6 +156,7 @@ class ConfigurableStep(Module):
             self.loss_accumulator.add_loss(loss_name, l)
             for n, v in loss.extra_metrics():
                 self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
+            loss.clear_metrics()
 
         # In some cases, the loss could not be set (e.g. all losses have 'after'
         if isinstance(total_loss, torch.Tensor):
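On the consuming side, ConfigurableStep now drains each loss's metrics once per pass and clears them afterwards, which is why the per-loss resets above could be removed. A simplified sketch of that handshake; LossAccumulator and the surrounding loop are stand-ins, and only the add_loss/extra_metrics/clear_metrics calls mirror the hunk:

class LossAccumulator:
    # Minimal stand-in for the real accumulator; it just buckets values by name.
    def __init__(self):
        self.buffer = {}

    def add_loss(self, name, value):
        self.buffer.setdefault(name, []).append(float(value))


def accumulate_losses(losses, loss_accumulator, net, state):
    # losses: dict of name -> ConfigurableLoss-style callables.
    total_loss = 0
    for loss_name, loss in losses.items():
        l = loss(net, state)
        total_loss = total_loss + l
        loss_accumulator.add_loss(loss_name, l)
        for n, v in loss.extra_metrics():
            loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
        # New in this change: reset the per-loss metrics so they do not carry
        # over into the next training step.
        loss.clear_metrics()
    return total_loss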
@@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
         self.flow = opt['flow_network']
         self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
-        self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0
+        self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
         self.recurrent_index = opt['recurrent_index']
         self.scale = opt['scale']
         self.resample = Resample2d()
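The changed line above fixes a mismatched key name: the old code read opt['output_hq_index'] but checked for 'output_index', so a configured output_hq_index was never picked up (and the lookup could raise a KeyError when only 'output_index' was present). Assuming opt behaves like a standard dict, the same logic can be written with dict.get; this is only an illustration, not the code in the repository:

# Equivalent to the corrected line, with a single key name used for both the
# lookup and the default.
self.output_hq_index = opt.get('output_hq_index', 0)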
@@ -271,3 +271,5 @@ class PingPongLoss(ConfigurableLoss):
         for i in range(cnt):
             img = imglist[:, i]
             torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
+
+