From 3a5b23b9f7a15ea85fdcfb99e5f86abdafb8521f Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 10 Oct 2020 20:21:09 -0600 Subject: [PATCH] Alter teco_losses to feed a recurrent input in as separate --- codes/models/steps/tecogan_losses.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 5565b749..46e5335a 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -62,6 +62,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): self.flow = opt['flow_network'] self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 + self.recurrent_index = opt['recurrent_index'] self.scale = opt['scale'] self.resample = Resample2d() self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'. @@ -83,17 +84,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector): for i in range(f): if first_step: input = extract_inputs_index(first_inputs, i) - recurrent_input = input[self.input_lq_index] + recurrent_input = torch.zeros_like(input[self.recurrent_index]) first_step = False else: input = extract_inputs_index(inputs, i) with torch.no_grad(): reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) - flowfield = flow(flow_input) + flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') # Resample does not work in FP16. recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) - input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) debug_index += 1 @@ -111,9 +112,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector): with torch.no_grad(): reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) - flowfield = flow(flow_input) + flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) - input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + input[self.recurrent_index + ] = recurrent_input if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) debug_index += 1 @@ -126,7 +128,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): return {self.output: results} def produce_teco_visual_debugs(self, gen_input, it): - if dist.get_rank() > 0: + if self.env['rank'] > 0: return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) @@ -174,6 +176,7 @@ class TecoGanLoss(ConfigurableLoss): self.image_flow_generator = opt['image_flow_generator'] self.resampler = Resample2d() self.for_generator = opt['for_generator'] + self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0 self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors. def forward(self, _, state): @@ -202,19 +205,21 @@ class TecoGanLoss(ConfigurableLoss): l_real = self.criterion(d_real, True) else: l_real = 0 - l_total += l_fake + l_real + l_step = l_fake + l_real elif self.opt['gan_type'] == 'ragan': d_fake_diff = d_fake - torch.mean(d_real) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) - l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) + + l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) + self.criterion(d_fake_diff, self.for_generator)) else: raise NotImplementedError + if l_step > self.min_loss: + l_total += l_step return l_total def produce_teco_visual_debugs(self, sext, lbl, it): - if dist.get_rank() > 0: + if self.env['rank'] > 0: return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl) os.makedirs(base_path, exist_ok=True) @@ -244,7 +249,7 @@ class PingPongLoss(ConfigurableLoss): return l_total def produce_teco_visual_debugs(self, imglist): - if dist.get_rank() > 0: + if self.env['rank'] > 0: return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) os.makedirs(base_path, exist_ok=True)