From c93dd623d70907ae5b9050f5f5b26b15e59200d6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 7 Oct 2020 23:11:58 -0600 Subject: [PATCH] Tecogan losses work --- codes/models/steps/tecogan_losses.py | 55 +++++++++++++++------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index e6c6fa92..5857a5f5 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -26,10 +26,10 @@ def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_ triplet = input_list[:, index:index+3] # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it. with torch.no_grad(): - first_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,0]], dim=2)) - first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic') - last_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,2]], dim=2)) - last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic') + first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float()) + #first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic') + last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float()) + #last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic') flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()), triplet[:,1], resampler(triplet[:,2].float(), last_flow.float())] @@ -61,25 +61,29 @@ class RecurrentImageGeneratorSequenceInjector(Injector): self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 self.scale = opt['scale'] self.resample = Resample2d() + self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'. + self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True def forward(self, state): gen = self.env['generators'][self.opt['generator']] flow = self.env['generators'][self.flow] results = [] + first_inputs = extract_params_from_state(self.first_inputs, state) inputs = extract_params_from_state(self.input, state) if not isinstance(inputs, list): inputs = [inputs] - recurrent_input = torch.zeros_like(inputs[self.input_lq_index][:,0]) # Go forward in the sequence first. first_step = True b, f, c, h, w = inputs[self.input_lq_index].shape debug_index = 0 for i in range(f): - input = extract_inputs_index(inputs, i) if first_step: + input = extract_inputs_index(first_inputs, i) + recurrent_input = input[self.input_lq_index] first_step = False else: + input = extract_inputs_index(inputs, i) with torch.no_grad(): reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) @@ -87,7 +91,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): # Resample does not work in FP16. recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) - if self.env['step'] % 20 == 0: + if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) debug_index += 1 gen_out = gen(*input) @@ -95,21 +99,22 @@ class RecurrentImageGeneratorSequenceInjector(Injector): results.append(recurrent_input) # Now go backwards, skipping the last element (it's already stored in recurrent_input) - it = reversed(range(f - 1)) - for i in it: - input = extract_inputs_index(inputs, i) - with torch.no_grad(): - reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') - flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) - flowfield = flow(flow_input) - recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) - input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) - if self.env['step'] % 20 == 0: - self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) - debug_index += 1 - gen_out = gen(*input) - recurrent_input = gen_out[self.output_hq_index] - results.append(recurrent_input) + if self.do_backwards: + it = reversed(range(f - 1)) + for i in it: + input = extract_inputs_index(inputs, i) + with torch.no_grad(): + reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') + flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) + flowfield = flow(flow_input) + recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) + input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + if self.env['step'] % 50 == 0: + self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) + debug_index += 1 + gen_out = gen(*input) + recurrent_input = gen_out[self.output_hq_index] + results.append(recurrent_input) return {self.output: results} @@ -122,7 +127,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): # This is the temporal discriminator loss from TecoGAN. # -# It has a strict contact for 'real' and 'fake' inputs: +# It has a strict contract for 'real' and 'fake' inputs: # 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset # 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images. # @@ -161,7 +166,7 @@ class TecoGanLoss(ConfigurableLoss): self.metrics.append(("d_fake", torch.mean(d_fake))) self.metrics.append(("d_real", torch.mean(d_real))) - if self.for_generator and self.env['step'] % 20 == 0: + if self.for_generator and self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(fake_sext, 'fake', i) self.produce_teco_visual_debugs(real_sext, 'real', i) @@ -205,7 +210,7 @@ class PingPongLoss(ConfigurableLoss): late = fake[-i] l_total += self.criterion(early, late) - if self.env['step'] % 20 == 0: + if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(fake) return l_total