From 1c44d395af74b9b11167900cd2f9139f4cda1c78 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 7 Oct 2020 09:03:30 -0600 Subject: [PATCH] Tecogan work Its training! There's still probably plenty of bugs though.. --- codes/models/networks.py | 14 ++- codes/models/steps/losses.py | 24 ++++ codes/models/steps/tecogan_losses.py | 158 +++++++++++++++------------ 3 files changed, 119 insertions(+), 77 deletions(-) diff --git a/codes/models/networks.py b/codes/models/networks.py index 530e57da..3f5ee5e1 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -1,3 +1,4 @@ +import munch import torch import logging from munch import munchify @@ -14,7 +15,6 @@ import models.archs.rcan as rcan from collections import OrderedDict import torchvision import functools -from models.flownet2.models import FlowNet2 logger = logging.getLogger('base') @@ -86,20 +86,24 @@ def define_G(opt, net_key='network_G', scale=None): init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == 'stacked_switches': xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 - netG = ssg.StackedSwitchGenerator(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3 + netG = ssg.StackedSwitchGenerator(in_nc=in_nc, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == 'stacked_switches_5lyr': xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 - netG = ssg.StackedSwitchGenerator5Layer(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], + in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3 + netG = ssg.StackedSwitchGenerator5Layer(in_nc=in_nc, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == 'ssg_deep': xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == "flownet2": - args_dict = {} - args = munchify(args_dict) + from models.flownet2.models import FlowNet2 + ld = torch.load(opt_net['load_path']) + args = munch.Munch({'fp16': False, 'rgb_max': 1.0}) netG = FlowNet2(args) + netG.load_state_dict(ld['state_dict']) elif which_model == "backbone_encoder": netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) elif which_model == "backbone_encoder_no_ref": diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index c6644d39..2178aa3e 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -28,6 +28,8 @@ def create_loss(opt_loss, env): return TranslationInvarianceLoss(opt_loss, env) elif type == 'recursive': return RecursiveInvarianceLoss(opt_loss, env) + elif type == 'recurrent': + return RecurrentLoss(opt_loss, env) else: raise NotImplementedError @@ -372,3 +374,25 @@ class RecursiveInvarianceLoss(ConfigurableLoss): else: return self.criterion(compare_real, compare_fake) + +# Loss that pulls tensors from dim 1 of the input and repeatedly feeds them into the +# 'subtype' loss. +class RecurrentLoss(ConfigurableLoss): + def __init__(self, opt, env): + super(RecurrentLoss, self).__init__(opt, env) + o = opt.copy() + o['type'] = opt['subtype'] + o['fake'] = '_fake' + o['real'] = '_real' + self.loss = create_loss(o, self.env) + + def forward(self, net, state): + total_loss = 0 + st = state.copy() + real = state[self.opt['real']] + for i in range(real.shape[1]): + st['_real'] = real[:, i] + st['_fake'] = state[self.opt['fake']][i] + total_loss += self.loss(net, st) + return total_loss + diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index a184f51f..a9ca6d1f 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -1,29 +1,52 @@ -from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state +from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name from models.layers.resample2d_package.resample2d import Resample2d from models.steps.recurrent import RecurrentController from models.steps.injectors import Injector import torch +import torch.nn.functional as F import os import os.path as osp import torchvision def create_teco_loss(opt, env): type = opt['type'] - if type == 'teco_generator_gan': - return TecoGanGeneratorLoss(opt, env) - elif type == 'teco_discriminator_gan': - return TecoGanDiscriminatorLoss(opt, env) + if type == 'teco_gan': + return TecoGanLoss(opt, env) elif type == "teco_pingpong": return PingPongLoss(opt, env) return None -def create_teco_discriminator_sextuplet(input_list, index, flow_gen, resampler): - triplet = input_list[index:index+3] - first_flow = flow_gen(triplet[0], triplet[1]) - last_flow = flow_gen(triplet[2], triplet[1]) - flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)] - return torch.cat(triplet + flow_triplet, dim=1) +def create_teco_injector(opt, env): + type = opt['type'] + if type == 'teco_recurrent_generated_sequence_injector': + return RecurrentImageGeneratorSequenceInjector(opt, env) + return None +def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler): + triplet = input_list[:, index:index+3] + # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it. + with torch.no_grad(): + first_flow = flow_gen(torch.stack([lr_imgs[:,0], lr_imgs[:,1]], dim=2)) + first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic') + last_flow = flow_gen(torch.stack([lr_imgs[:,2], lr_imgs[:,1]], dim=2)) + last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic') + flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()), + triplet[:,1], + resampler(triplet[:,2].float(), last_flow.float())] + flow_triplet = torch.stack(flow_triplet, dim=2) + combined = torch.cat([triplet, flow_triplet], dim=2) + b, f, c, h, w = combined.shape + return combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here. + + +def extract_inputs_index(inputs, i): + res = [] + for input in inputs: + if isinstance(input, torch.Tensor): + res.append(input[:, i]) + else: + res.append(input) + return res # Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out] # Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated). @@ -32,32 +55,51 @@ class RecurrentImageGeneratorSequenceInjector(Injector): def __init__(self, opt, env): super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env) self.flow = opt['flow_network'] + self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0 + self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 + self.scale = opt['scale'] self.resample = Resample2d() def forward(self, state): gen = self.env['generators'][self.opt['generator']] flow = self.env['generators'][self.flow] results = [] - recurrent_input = torch.zeros_like(state[self.input][0]) + inputs = extract_params_from_state(self.input, state) + if not isinstance(inputs, list): + inputs = [inputs] + recurrent_input = torch.zeros_like(inputs[self.input_lq_index][:,0]) # Go forward in the sequence first. first_step = True - for input in state[self.input]: + b, f, c, h, w = inputs[self.input_lq_index].shape + for i in range(f): + input = extract_inputs_index(inputs, i) if first_step: first_step = False else: - flowfield = flow(recurrent_input, input) - recurrent_input = self.resample(recurrent_input, flowfield) - recurrent_input = gen(input, recurrent_input) + with torch.no_grad(): + reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') + flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) + flowfield = flow(flow_input) + # Resample does not work in FP16. + recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) + input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + gen_out = gen(*input) + recurrent_input = gen_out[self.output_hq_index] results.append(recurrent_input) - recurrent_input = self.flow() # Now go backwards, skipping the last element (it's already stored in recurrent_input) - it = reversed(range(len(results) - 1)) + it = reversed(range(f - 1)) for i in it: - flowfield = flow(recurrent_input, results[i]) - recurrent_input = self.resample(recurrent_input, flowfield) - recurrent_input = gen(results[i], recurrent_input) + input = extract_inputs_index(inputs, i) + with torch.no_grad(): + reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') + flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) + flowfield = flow(flow_input) + recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) + input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) + gen_out = gen(*input) + recurrent_input = gen_out[self.output_hq_index] results.append(recurrent_input) return {self.output: results} @@ -76,76 +118,48 @@ class RecurrentImageGeneratorSequenceInjector(Injector): # 4) Composes the three base image and the 2 warped images and middle image into a tensor concatenated at the filter dimension for both real and fake, resulting in a bx18xhxw shape tensor. # 5) Feeds the catted real and fake image sets into the discriminator, computes a loss, and backward(). # 6) Repeat from (1) until all triplets from the real sequence have been exhausted. -class TecoGanDiscriminatorLoss(ConfigurableLoss): +class TecoGanLoss(ConfigurableLoss): def __init__(self, opt, env): - super(TecoGanDiscriminatorLoss, self).__init__(opt, env) - self.opt = opt - self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) - self.noise = None if 'noise' not in opt.keys() else opt['noise'] - self.image_flow_generator = opt['image_flow_generator'] - self.resampler = Resample2d() - - def forward(self, net, state): - self.metrics = [] - flow_gen = self.env['generators'][self.image_flow_generator] - real = state[self.opt['real']] - fake = state[self.opt['fake']] - l_total = 0 - for i in range(len(real) - 2): - real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler) - fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler) - - d_real = net(real_sext) - d_fake = net(fake_sext) - - if self.opt['gan_type'] in ['gan', 'pixgan']: - self.metrics.append(("d_fake", torch.mean(d_fake))) - self.metrics.append(("d_real", torch.mean(d_real))) - l_real = self.criterion(d_real, True) - l_fake = self.criterion(d_fake, False) - l_total += l_real + l_fake - elif self.opt['gan_type'] == 'ragan': - d_fake_diff = d_fake - torch.mean(d_real) - self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) - l_total += (self.criterion(d_real - torch.mean(d_fake), True) + - self.criterion(d_fake_diff, False)) - else: - raise NotImplementedError - return l_total - - -class TecoGanGeneratorLoss(ConfigurableLoss): - def __init__(self, opt, env): - super(TecoGanGeneratorLoss, self).__init__(opt, env) + super(TecoGanLoss, self).__init__(opt, env) self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) # TecoGAN parameters + self.scale = opt['scale'] + self.lr_inputs = opt['lr_inputs'] self.image_flow_generator = opt['image_flow_generator'] self.resampler = Resample2d() + self.for_generator = opt['for_generator'] def forward(self, _, state): + net = self.env['discriminators'][self.opt['discriminator']] flow_gen = self.env['generators'][self.image_flow_generator] real = state[self.opt['real']] - fake = state[self.opt['fake']] + fake = torch.stack(state[self.opt['fake']], dim=1) + sequence_len = real.shape[1] + lr = state[self.opt['lr_inputs']] l_total = 0 - for i in range(len(real) - 2): - real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler) - fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler) + for i in range(sequence_len - 2): + real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler) + fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler) d_fake = net(fake_sext) - if self.env['step'] % 100 == 0: + if self.for_generator and self.env['step'] % 100 == 0: self.produce_teco_visual_debugs(fake_sext, 'fake', i) self.produce_teco_visual_debugs(real_sext, 'real', i) if self.opt['gan_type'] in ['gan', 'pixgan']: self.metrics.append(("d_fake", torch.mean(d_fake))) - l_fake = self.criterion(d_fake, True) - l_total += l_fake + l_fake = self.criterion(d_fake, self.for_generator) + if not self.for_generator: + l_real = self.criterion(d_real, True) + else: + l_real = 0 + l_total += l_fake + l_real elif self.opt['gan_type'] == 'ragan': d_real = net(real_sext) d_fake_diff = d_fake - torch.mean(d_real) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) - l_total += (self.criterion(d_real - torch.mean(d_fake), False) + - self.criterion(d_fake_diff, True)) + l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) + + self.criterion(d_fake_diff, self.for_generator)) else: raise NotImplementedError @@ -164,12 +178,12 @@ class PingPongLoss(ConfigurableLoss): def __init__(self, opt, env): super(PingPongLoss, self).__init__(opt, env) self.opt = opt - self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) + self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) def forward(self, _, state): fake = state[self.opt['fake']] l_total = 0 - for i in range((len(fake) - 1) / 2): + for i in range((len(fake) - 1) // 2): early = fake[i] late = fake[-i] l_total += self.criterion(early, late)