From 731700ab2ccbfe0ffd56c93c3ba044097cf84e3f Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 12 Oct 2020 17:43:28 -0600 Subject: [PATCH 1/2] checkpoint in ssg --- codes/models/archs/SwitchedResidualGenerator_arch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index b79fd55f..a753e407 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -75,6 +75,7 @@ def gather_2d(input, index): return result +from utils.util import checkpoint class ConfigurableSwitchComputer(nn.Module): def __init__(self, base_filters, multiplexer_net, pre_transform_block, transform_block, transform_count, attention_norm, init_temp=20, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=False): @@ -131,13 +132,13 @@ class ConfigurableSwitchComputer(nn.Module): x = self.pre_transform(*x) if not isinstance(x, tuple): x = (x,) - xformed = [t(*x) for t in self.transforms] + xformed = [checkpoint(t, *x) for t in self.transforms] if not isinstance(att_in, tuple): att_in = (att_in,) if self.feed_transforms_into_multiplexer: att_in = att_in + (torch.stack(xformed, dim=1),) - m = self.multiplexer(*att_in) + m = checkpoint(self.multiplexer, *att_in) # It is assumed that [xformed] and [m] are collapsed into tensors at this point. outputs, attention = self.switch(xformed, m, True, self.update_norm) @@ -603,4 +604,4 @@ if __name__ == '__main__': trans = [torch.randn(4,64,64,64) for t in range(10)] b = bb(x, r, cp) - emb(xu, b, trans) \ No newline at end of file + emb(xu, b, trans) From 8014f050ac2488f87e5ee025259ef64b675a1f6f Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 13 Oct 2020 10:07:49 -0600 Subject: [PATCH 2/2] Clear metrics properly Holy cow, what a PITA bug. --- codes/models/steps/losses.py | 7 +++---- codes/models/steps/steps.py | 1 + codes/models/steps/tecogan_losses.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 3e6eeb6f..0e37e2c7 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -64,6 +64,9 @@ class ConfigurableLoss(nn.Module): def extra_metrics(self): return self.metrics + def clear_metrics(self): + self.metrics = [] + def get_basic_criterion_for_name(name, device): if name == 'l1': @@ -192,7 +195,6 @@ class DiscriminatorGanLoss(ConfigurableLoss): self.losses_computed = 0 def forward(self, net, state): - self.metrics = [] real = extract_params_from_state(self.opt['real'], state) real = [r.detach() for r in real] fake = extract_params_from_state(self.opt['fake'], state) @@ -258,7 +260,6 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss): (functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))]) def forward(self, net, state): - self.metrics = [] net = self.env['generators'][self.generator] # Get the network from an explicit parameter. # The parameter is not reliable for generator losses since often they are combined with many networks. fake = extract_params_from_state(self.opt['fake'], state) @@ -305,7 +306,6 @@ class TranslationInvarianceLoss(ConfigurableLoss): assert(self.patch_size > self.overlap) def forward(self, net, state): - self.metrics = [] net = self.env['generators'][self.generator] # Get the network from an explicit parameter. # The parameter is not reliable for generator losses since often they are combined with many networks. @@ -356,7 +356,6 @@ class RecursiveInvarianceLoss(ConfigurableLoss): assert(self.recursive_depth > 0) def forward(self, net, state): - self.metrics = [] net = self.env['generators'][self.generator] # Get the network from an explicit parameter. # The parameter is not reliable for generator losses since they can be combined with many networks. diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index f7c9ca66..6cf21b09 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -156,6 +156,7 @@ class ConfigurableStep(Module): self.loss_accumulator.add_loss(loss_name, l) for n, v in loss.extra_metrics(): self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v) + loss.clear_metrics() # In some cases, the loss could not be set (e.g. all losses have 'after' if isinstance(total_loss, torch.Tensor): diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 9f320506..c95e214b 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -61,7 +61,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env) self.flow = opt['flow_network'] self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 - self.output_hq_index = opt['output_hq_index'] if 'output_index' in opt.keys() else 0 + self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 self.recurrent_index = opt['recurrent_index'] self.scale = opt['scale'] self.resample = Resample2d() @@ -271,3 +271,5 @@ class PingPongLoss(ConfigurableLoss): for i in range(cnt): img = imglist[:, i] torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) + +