From 6d0490a0e617c78215ed8408ef2fb48857f64bd5 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 25 Sep 2020 16:38:23 -0600
Subject: [PATCH] Tecogan implementation work

---
 codes/models/steps/steps.py                    |  4 +
 ..._gen_injectors.py => tecogan_injectors.py}  | 25 ++++---
 codes/models/steps/tecogan_losses.py           | 74 +++++++++++++++++++
 3 files changed, 94 insertions(+), 9 deletions(-)
 rename codes/models/steps/{recursive_gen_injectors.py => tecogan_injectors.py} (53%)
 create mode 100644 codes/models/steps/tecogan_losses.py

diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 84d635b9..72086f47 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -109,6 +109,10 @@ class ConfigurableStep(Module):
         local_state.update(new_state)
         local_state['train_nets'] = str(self.get_networks_trained())
 
+        # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
+        self.env['amp_loss_id'] = amp_loss_id
+        self.env['current_step_optimizers'] = self.optimizers
+
         # Inject in any extra dependencies.
         for inj in self.injectors:
             # Don't do injections tagged with eval unless we are not in train mode.
diff --git a/codes/models/steps/recursive_gen_injectors.py b/codes/models/steps/tecogan_injectors.py
similarity index 53%
rename from codes/models/steps/recursive_gen_injectors.py
rename to codes/models/steps/tecogan_injectors.py
index 6b97ff3b..981ef8af 100644
--- a/codes/models/steps/recursive_gen_injectors.py
+++ b/codes/models/steps/tecogan_injectors.py
@@ -1,24 +1,31 @@
 import models.steps.injectors as injectors
+import torch
 
 # Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
-# All results are checkpointed for memory savings. Recurrent inputs are also detached before being fed back into
-# the generator.
+# Images are fed in sequentially forward then backward, resulting in len([out])=2*len([in])-1 (the last element is not repeated).
+# All computation is done with torch.no_grad().
 class RecurrentImageGeneratorSequenceInjector(injectors.Injector):
     def __init__(self, opt, env):
         super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
-        new_state = {}
         results = []
-        recurrent_input = torch.zeros_like(state[self.input][0])
-        for input in state[self.input]:
-            result = checkpoint(gen, input, recurrent_input)
-            results.append(result)
-            recurrent_input = result.detach()
+        with torch.no_grad():
+            recurrent_input = torch.zeros_like(state[self.input][0])
+            # Go forward in the sequence first.
+            for input in state[self.input]:
+                recurrent_input = gen(input, recurrent_input)
+                results.append(recurrent_input)
 
-        new_state = {self.output: results}
+            # Now go backwards, skipping the last element (it's already stored in recurrent_input).
+            it = reversed(range(len(results) - 1))
+            for i in it:
+                recurrent_input = gen(results[i], recurrent_input)
+                results.append(recurrent_input)
+
+            new_state = {self.output: results}
 
         return new_state
diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py
new file mode 100644
index 00000000..9ef91215
--- /dev/null
+++ b/codes/models/steps/tecogan_losses.py
@@ -0,0 +1,74 @@
+from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state
+from models.layers.resample2d_package.resample2d import Resample2d
+import torch
+from apex import amp
+
+
+def create_teco_discriminator_sextuplet(input_list, index, flow_gen, resampler, detach=True):
+    triplet = input_list[index:index+3]
+    first_flow = flow_gen(triplet[1], triplet[0])
+    last_flow = flow_gen(triplet[1], triplet[2])
+    if detach:
+        first_flow = first_flow.detach()
+        last_flow = last_flow.detach()
+    flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)]
+    return torch.cat(triplet + flow_triplet, dim=1)
+
+# This is the temporal discriminator loss from TecoGAN.
+#
+# It has a strict contract for 'real' and 'fake' inputs:
+# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
+# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
+#
+# This loss does the following:
+# 1) Picks an image triplet, starting with the first 3 elements in 'real' and 'fake'.
+# 2) Uses the image flow generator (specified with 'image_flow_generator') to create detached flow fields for the first and last images in the above sequence.
+# 3) Warps the first and last images according to the flow fields.
+# 4) Concatenates the three base images with the warped first image, the (unwarped) middle image and the warped last image along the filter dimension, for both real and fake, resulting in a bx18xhxw tensor.
+# 5) Feeds the concatenated real and fake image sets into the discriminator, computes a loss, and calls backward().
+# 6) Repeats from (1) until all triplets from the real sequence have been exhausted.
+#
+# Note: All steps before 'discriminator_flow_after' do not use triplets. Instead, they use a single image repeated 6 times across the filter dimension.
+class TecoGanDiscriminatorLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(TecoGanDiscriminatorLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
+        self.discriminator_flow_after = opt['discriminator_flow_after']
+        self.image_flow_generator = opt['image_flow_generator']
+        self.resampler = Resample2d()
+
+    def forward(self, net, state):
+        self.metrics = []
+        flow_gen = self.env['generators'][self.image_flow_generator]
+        real = state[self.opt['real']]
+        fake = state[self.opt['fake']]
+        backwards_count = len(real) - 2
+        for i in range(len(real) - 2):
+            # TODO: implement the single-image mode used before 'discriminator_flow_after' (see the note above).
+            real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
+            fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
+
+            d_real = net(real_sext)
+            d_fake = net(fake_sext)
+
+            if self.opt['gan_type'] in ['gan', 'pixgan']:
+                self.metrics.append(("d_fake", torch.mean(d_fake)))
+                self.metrics.append(("d_real", torch.mean(d_real)))
+                l_real = self.criterion(d_real, True)
+                l_fake = self.criterion(d_fake, False)
+                l_total = l_real + l_fake
+            elif self.opt['gan_type'] == 'ragan':
+                d_fake_diff = d_fake - torch.mean(d_real)
+                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
+                l_total = (self.criterion(d_real - torch.mean(d_fake), True) +
+                           self.criterion(d_fake_diff, False))
+            else:
+                raise NotImplementedError
+
+            l_total = l_total / backwards_count
+            if self.env['amp']:
+                with amp.scale_loss(l_total, self.env['current_step_optimizers'][0], self.env['amp_loss_id']) as loss:
+                    loss.backward()
+            else:
+                l_total.backward()
\ No newline at end of file
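
As a sanity check on the contract described in the tecogan_losses.py comment block, the sextuplet construction and the injector's output length can be exercised in isolation. The sketch below is not part of the patch: toy_flow and toy_warp are simplified stand-ins for the image flow generator and Resample2d (they only mimic tensor shapes, not real optical flow), and create_sextuplet mirrors create_teco_discriminator_sextuplet from the diff.

    import torch


    def toy_flow(ref, img):
        # Stand-in flow estimator: returns an all-zero b x 2 x h x w flow field.
        return torch.zeros(ref.shape[0], 2, ref.shape[2], ref.shape[3])


    def toy_warp(img, flow):
        # Stand-in for Resample2d: with a zero flow the image is returned unchanged.
        return img


    def create_sextuplet(frames, index, flow_gen, resampler, detach=True):
        # Mirrors create_teco_discriminator_sextuplet: a triplet of frames plus the first and
        # last frames warped toward the middle one, concatenated along the channel dimension.
        triplet = frames[index:index + 3]
        first_flow = flow_gen(triplet[1], triplet[0])
        last_flow = flow_gen(triplet[1], triplet[2])
        if detach:
            first_flow, last_flow = first_flow.detach(), last_flow.detach()
        flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)]
        return torch.cat(triplet + flow_triplet, dim=1)


    if __name__ == '__main__':
        frames = [torch.randn(4, 3, 16, 16) for _ in range(5)]  # five RGB frames, batch size 4
        sext = create_sextuplet(frames, 0, toy_flow, toy_warp)
        print(sext.shape)           # torch.Size([4, 18, 16, 16]) -- the bx18xhxw discriminator input
        print(2 * len(frames) - 1)  # 9 -- the output count of the recurrent injector's forward/backward pass

Running it prints torch.Size([4, 18, 16, 16]) and 9, matching the bx18xhxw shape from step (4) of the loss comment and the 2*len([in])-1 output count from the injector comment.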