diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index f718de9f..51801827 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -31,8 +31,6 @@ def create_injector(opt_inject, env):
         return GreyInjector(opt_inject, env)
     elif type == 'interpolate':
         return InterpolateInjector(opt_inject, env)
-    elif type == 'imageflow':
-        return ImageFlowInjector(opt_inject, env)
     elif type == 'image_patch':
         return ImagePatchInjector(opt_inject, env)
     elif type == 'concatenate':
diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py
index b460a15d..100ee9fa 100644
--- a/codes/models/steps/tecogan_losses.py
+++ b/codes/models/steps/tecogan_losses.py
@@ -26,23 +26,6 @@ def create_teco_injector(opt, env):
         return FlowAdjustment(opt, env)
     return None
 
-def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
-    # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
-    with autocast(enabled=False):
-        triplet = input_list[:, index:index+3].float()
-        first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2))
-        last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2))
-        flow_triplet = [resampler(triplet[:,0], first_flow),
-                        triplet[:,1],
-                        resampler(triplet[:,2], last_flow)]
-        flow_triplet = torch.stack(flow_triplet, dim=1)
-        combined = torch.cat([triplet, flow_triplet], dim=1)
-        b, f, c, h, w = combined.shape
-        combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
-    # Apply margin
-    return combined[:, :, margin:-margin, margin:-margin]
-
-
 def extract_inputs_index(inputs, i):
     res = []
     for input in inputs:
@@ -152,9 +135,14 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                 results[out_key].append(gen_out[i])
             recurrent_input = gen_out[self.output_hq_index]
+        final_results = {}
+        # Also produce 'hq_batched' here: it is simple enough that it does not warrant a separate injector.
+        b, s, c, h, w = state['hq'].shape
+        final_results['hq_batched'] = state['hq'].view(b*s, c, h, w)
         for k, v in results.items():
-            results[k] = torch.stack(v, dim=1)
-        return results
+            final_results[k] = torch.stack(v, dim=1)
+            final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.
+        return final_results
 
     def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
         if self.env['rank'] > 0:
             return
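For clarity, here is a minimal sketch of the shape bookkeeping behind the new `_batched` outputs in the hunk above. The tensors are toys; the `(b, s, c, h, w)` layout and the `v[:s]` truncation follow the diff, everything else is illustrative:

    import torch

    # Toy shapes mirroring the injector: b=batch, s=sequence length, c/h/w=image dims.
    b, s, c, h, w = 2, 5, 3, 16, 16
    hq = torch.randn(b, s, c, h, w)
    hq_batched = hq.view(b * s, c, h, w)   # batch-major: all frames of item 0 come first

    # A generated output list can be longer than s (e.g. extra recurrent passes);
    # per the diff, only the first s entries are folded into the batch dim.
    v = [torch.randn(b, c, h, w) for _ in range(s + 2)]
    stacked = torch.stack(v, dim=1)        # (b, s+2, c, h, w)
    batched = torch.cat(v[:s], dim=0)      # (b*s, c, h, w), frame-major: frame 0 of every item first

Note that the two flattenings order elements differently (`view` is batch-major, `cat` along dim 0 is frame-major), so a consumer comparing `hq_batched` against a `*_batched` output elementwise would need to account for the pairing.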
@@ -183,6 +171,47 @@ class FlowAdjustment(Injector):
         return {self.output: self.resample(state[self.flowed], flowfield)}
 
 
+def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
+    # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
+    with autocast(enabled=False):
+        triplet = input_list[:, index:index+3].float()
+        first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2))
+        last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2))
+        flow_triplet = [resampler(triplet[:,0], first_flow),
+                        triplet[:,1],
+                        resampler(triplet[:,2], last_flow)]
+        flow_triplet = torch.stack(flow_triplet, dim=1)
+        combined = torch.cat([triplet, flow_triplet], dim=1)
+        b, f, c, h, w = combined.shape
+        combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
+    # Apply margin
+    return combined[:, :, margin:-margin, margin:-margin]
+
+
+def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_gen, resampler, margin):
+    # Combine everything and feed it into the flow network at once for better efficiency.
+    batch_sz = input_list.shape[0]
+    flux_doubles_forward = [torch.stack([input_list[:,i], input_list[:,i+1]], dim=2) for i in range(1, total+1)]
+    flux_doubles_backward = [torch.stack([input_list[:,i], input_list[:,i-1]], dim=2) for i in range(1, total+1)]
+    flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0))
+    flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0))
+    sexts = []
+    for i in range(total):
+        flow_forward = flows_forward[batch_sz*i:batch_sz*(i+1)]
+        flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)]
+        mid = input_list[:,i+1]
+        sext = torch.stack([input_list[:,i], mid, input_list[:,i+2],
+                            resampler(mid, flow_backward),
+                            mid,
+                            resampler(mid, flow_forward)], dim=1)
+        # Apply margin
+        b, f, c, h, w = sext.shape
+        sext = sext.view(b, 3*6, h, w) # f*c = 6*3
+        sext = sext[:, :, margin:-margin, margin:-margin]
+        sexts.append(sext)
+    return torch.cat(sexts, dim=0)
+
+
 # This is the temporal discriminator loss from TecoGAN.
 #
 # It has a strict contract for 'real' and 'fake' inputs:
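A small, self-contained walkthrough of the sextuplet layout built above may help. The flow generator and resampler are stubbed with placeholders here (their real signatures live elsewhere in this repo), so this only demonstrates the shape contract the discriminator relies on:

    import torch

    b, c, h, w, margin = 2, 3, 32, 32, 8
    seq = torch.randn(b, 5, c, h, w)                 # (b, frames, c, h, w)
    triplet = seq[:, 0:3]                            # three consecutive frames

    flow_gen = lambda pair: torch.zeros(b, 2, h, w)  # placeholder flow network
    resampler = lambda img, flow: img                # placeholder warp (identity)

    first_flow = flow_gen(torch.stack([triplet[:, 1], triplet[:, 0]], dim=2))
    last_flow = flow_gen(torch.stack([triplet[:, 1], triplet[:, 2]], dim=2))
    flow_triplet = torch.stack([resampler(triplet[:, 0], first_flow),
                                triplet[:, 1],
                                resampler(triplet[:, 2], last_flow)], dim=1)
    combined = torch.cat([triplet, flow_triplet], dim=1)  # (b, 6, c, h, w)
    combined = combined.view(b, 3 * 6, h, w)              # 18 channels; only valid when c == 3
    inner = combined[:, :, margin:-margin, margin:-margin]
    assert inner.shape == (b, 18, h - 2 * margin, w - 2 * margin)

The discriminator thus always sees 18 channels: the raw triplet plus the flow-warped triplet, with the margin cropped to hide warping artifacts at the image edges.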
@@ -208,48 +237,85 @@ class TecoGanLoss(ConfigurableLoss):
         self.for_generator = opt['for_generator']
         self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
         self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
+        self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
 
     def forward(self, _, state):
-        fp16 = self.env['opt']['fp16']
-        net = self.env['discriminators'][self.opt['discriminator']]
+        if self.ff:
+            return self.fast_forward(state)
+        else:
+            return self.lowmem_forward(state)
+
+
+    # Computes the discriminator loss one recurrent step at a time, which has a lower memory overhead but is
+    # slower.
+    def lowmem_forward(self, state):
         flow_gen = self.env['generators'][self.image_flow_generator]
         real = state[self.opt['real']]
         fake = state[self.opt['fake']]
         sequence_len = real.shape[1]
         lr = state[self.opt['lr_inputs']]
         l_total = 0
+
+        # Build each sextuplet and compute its loss one step at a time to keep memory usage low.
         for i in range(sequence_len - 2):
             real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
             fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            with autocast(enabled=fp16):
-                d_fake = net(fake_sext)
-                d_real = net(real_sext)
-            self.metrics.append(("d_fake", torch.mean(d_fake)))
-            self.metrics.append(("d_real", torch.mean(d_real)))
-
-            if self.for_generator and self.env['step'] % 50 == 0:
-                self.produce_teco_visual_debugs(fake_sext, 'fake', i)
-                self.produce_teco_visual_debugs(real_sext, 'real', i)
-
-            if self.opt['gan_type'] in ['gan', 'pixgan']:
-                l_fake = self.criterion(d_fake, self.for_generator)
-                if not self.for_generator:
-                    l_real = self.criterion(d_real, True)
-                else:
-                    l_real = 0
-                l_step = l_fake + l_real
-            elif self.opt['gan_type'] == 'ragan':
-                d_fake_diff = d_fake - torch.mean(d_real)
-                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-                l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
-                          self.criterion(d_fake_diff, self.for_generator))
-            else:
-                raise NotImplementedError
+            l_step = self.compute_loss(real_sext, fake_sext)
             if l_step > self.min_loss:
                 l_total += l_step
 
         return l_total
 
+    # Computes the discriminator loss by packing all of the sextuplets into the batch dimension and doing one massive
+    # forward() on the discriminator. Uses more memory but is faster.
+    def fast_forward(self, state):
+        flow_gen = self.env['generators'][self.image_flow_generator]
+        real = state[self.opt['real']]
+        fake = state[self.opt['fake']]
+        sequence_len = real.shape[1]
+        lr = state[self.opt['lr_inputs']]
+
+        # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
+        combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
+                                                                 self.resampler, self.margin)
+        combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
+                                                                 self.resampler, self.margin)
+        l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
+        if l_total < self.min_loss:
+            l_total = 0
+        return l_total
+
+    def compute_loss(self, real_sext, fake_sext):
+        fp16 = self.env['opt']['fp16']
+        net = self.env['discriminators'][self.opt['discriminator']]
+        with autocast(enabled=fp16):
+            d_fake = net(fake_sext)
+            d_real = net(real_sext)
+
+        self.metrics.append(("d_fake", torch.mean(d_fake)))
+        self.metrics.append(("d_real", torch.mean(d_real)))
+
+        if self.for_generator and self.env['step'] % 50 == 0:
+            self.produce_teco_visual_debugs(fake_sext, 'fake', 0)
+            self.produce_teco_visual_debugs(real_sext, 'real', 0)
+
+        if self.opt['gan_type'] in ['gan', 'pixgan']:
+            l_fake = self.criterion(d_fake, self.for_generator)
+            if not self.for_generator:
+                l_real = self.criterion(d_real, True)
+            else:
+                l_real = 0
+            l_step = l_fake + l_real
+        elif self.opt['gan_type'] == 'ragan':
+            d_fake_diff = d_fake - torch.mean(d_real)
+            self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
+            l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
+                      self.criterion(d_fake_diff, self.for_generator))
+        else:
+            raise NotImplementedError
+
+        return l_step
+
     def produce_teco_visual_debugs(self, sext, lbl, it):
         if self.env['rank'] > 0:
             return
@@ -291,4 +357,3 @@ class PingPongLoss(ConfigurableLoss):
             img = imglist[:, i]
             torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))
-
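The new `fast_forward` flag is read straight from the loss options. A hypothetical options block is sketched below as a Python dict for illustration; only `fast_forward` is new, the other key names mirror those read in `__init__`/`forward` above, and the values (discriminator and state-key names in particular) are placeholders, since real configs in this project live in YAML:

    # Hypothetical TecoGanLoss options; key names follow the code above, values are placeholders.
    teco_gan_loss_opts = {
        'discriminator': 'tecogan_disc',  # placeholder discriminator name
        'real': 'hq',                     # state key holding the real sequence (b, s, c, h, w)
        'fake': 'gen',                    # placeholder state key for the generated sequence
        'lr_inputs': 'lq',                # placeholder state key for the LR frames used for flow
        'gan_type': 'ragan',
        'for_generator': True,
        'min_loss': 0,
        'margin': 8,
        'fast_forward': True,             # new: one batched discriminator pass instead of a per-step loop
    }

One observable difference between the two paths: `lowmem_forward` applies `min_loss` to each recurrent step individually, while `fast_forward` applies it once to the aggregate loss.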