diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 5565b749..46e5335a 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -62,6 +62,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): self.flow = opt['flow_network'] self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 + self.recurrent_index = opt['recurrent_index'] self.scale = opt['scale'] self.resample = Resample2d() self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'. @@ -83,17 +84,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector): for i in range(f): if first_step: input = extract_inputs_index(first_inputs, i) - recurrent_input = input[self.input_lq_index] + recurrent_input = torch.zeros_like(input[self.recurrent_index]) first_step = False else: input = extract_inputs_index(inputs, i) with torch.no_grad(): reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) - flowfield = flow(flow_input) + flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') # Resample does not work in FP16. recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) - input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) debug_index += 1 @@ -111,9 +112,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector): with torch.no_grad(): reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) - flowfield = flow(flow_input) + flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) - input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + input[self.recurrent_index + ] = recurrent_input if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) debug_index += 1 @@ -126,7 +128,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): return {self.output: results} def produce_teco_visual_debugs(self, gen_input, it): - if dist.get_rank() > 0: + if self.env['rank'] > 0: return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) @@ -174,6 +176,7 @@ class TecoGanLoss(ConfigurableLoss): self.image_flow_generator = opt['image_flow_generator'] self.resampler = Resample2d() self.for_generator = opt['for_generator'] + self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0 self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors. def forward(self, _, state): @@ -202,19 +205,21 @@ class TecoGanLoss(ConfigurableLoss): l_real = self.criterion(d_real, True) else: l_real = 0 - l_total += l_fake + l_real + l_step = l_fake + l_real elif self.opt['gan_type'] == 'ragan': d_fake_diff = d_fake - torch.mean(d_real) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) - l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) + + l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) + self.criterion(d_fake_diff, self.for_generator)) else: raise NotImplementedError + if l_step > self.min_loss: + l_total += l_step return l_total def produce_teco_visual_debugs(self, sext, lbl, it): - if dist.get_rank() > 0: + if self.env['rank'] > 0: return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl) os.makedirs(base_path, exist_ok=True) @@ -244,7 +249,7 @@ class PingPongLoss(ConfigurableLoss): return l_total def produce_teco_visual_debugs(self, imglist): - if dist.get_rank() > 0: + if self.env['rank'] > 0: return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) os.makedirs(base_path, exist_ok=True)