From 840927063a9903e45711d8833e46455f1cdb771d Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 5 Oct 2020 19:35:28 -0600
Subject: [PATCH] Work on tecogan losses

---
 codes/models/steps/losses.py         |   5 +-
 codes/models/steps/tecogan_custom.py | 154 -------------------------
 codes/models/steps/tecogan_losses.py | 162 +++++++++++++++++++++++++++
 3 files changed, 166 insertions(+), 155 deletions(-)
 delete mode 100644 codes/models/steps/tecogan_custom.py
 create mode 100644 codes/models/steps/tecogan_losses.py

diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index d5caa02f..c6644d39 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -9,7 +9,10 @@ import torchvision
 
 def create_loss(opt_loss, env):
     type = opt_loss['type']
-    if type == 'pix':
+    if 'teco_' in type:
+        from models.steps.tecogan_losses import create_teco_loss
+        return create_teco_loss(opt_loss, env)
+    elif type == 'pix':
         return PixLoss(opt_loss, env)
     elif type == 'feature':
         return FeatureLoss(opt_loss, env)
diff --git a/codes/models/steps/tecogan_custom.py b/codes/models/steps/tecogan_custom.py
deleted file mode 100644
index ba207f37..00000000
--- a/codes/models/steps/tecogan_custom.py
+++ /dev/null
@@ -1,154 +0,0 @@
-from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state
-from models.layers.resample2d_package.resample2d import Resample2d
-from models.steps.recurrent import RecurrentController
-from models.steps.injectors import Injector
-import torch
-from apex import amp
-
-
-def create_teco_discriminator_sextuplet(input_list, index, flow_gen, resampler, detach=True):
-    triplet = input_list[index:index+3]
-    first_flow = flow_gen(triplet[1], triplet[0])
-    last_flow = flow_gen(triplet[1], triplet[2])
-    if detach:
-        first_flow = first_flow.detach()
-        last_flow = last_flow.detach()
-    flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)]
-    return torch.cat(triplet + flow_triplet, dim=1)
-
-
-# Controller class that schedules the recurring inputs of tecogan
-class TecoGanController(RecurrentController):
-    def __init__(self, opt, env):
-        super(TecoGanController, self).__init__(opt, env)
-        self.sequence_len = opt['teco_sequence_length']
-
-    def get_next_step(self, state, recurrent_state):
-        # The first stage feeds the LR input into both generator inputs.
-        if recurrent_state is None:
-            return {
-                '_gen_lr_input_index': 0,
-                '_teco_recurrent_counter': 0
-                '_teco_stage': 0
-            }
-        # The second stage is truly recurrent, but needs its own stage counter because the temporal discriminator
-        # cannot come online yet.
-        elif recurrent_state['_teco_recurrent_counter'] == 1:
-            return {
-                '_gen_lr_input_index': 1,
-                '_teco_stage': 1,
-                '_teco_recurrent_counter': recurrent_state['_teco_recurrent_counter'] + 1
-            }
-        # The third stage is truly recurrent through the end of the sequence.
-        elif recurrent_state['_teco_recurrent_counter'] < self.sequence_len:
-            return {
-                '_gen_lr_input_index': recurrent_state['_gen_lr_input_index'] + 1,
-                '_teco_stage': 2,
-                '_teco_recurrent_counter': recurrent_state['_teco_recurrent_counter'] + 1
-            }
-        # The fourth stage regresses backwards through the sequence.
-        elif recurrent_state['_teco_recurrent_counter'] < self.sequence_len * 2 - 1:
-            return {
-                '_gen_lr_input_index': self.sequence_len - recurrent_state['teco_recurrent_counter'] - 1,
-                '_teco_stage': 3,
-                '_teco_recurrent_counter': recurrent_state['_teco_recurrent_counter'] + 1
-            }
-        else:
-            return None
-
-
-# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
-# Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated).
-# All computation is done with torch.no_grad().
-class RecurrentImageGeneratorSequenceInjector(Injector):
-    def __init__(self, opt, env):
-        super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
-
-    def forward(self, state):
-        gen = self.env['generators'][self.opt['generator']]
-        results = []
-        with torch.no_grad():
-            recurrent_input = torch.zeros_like(state[self.input][0])
-            # Go forward in the sequence first.
-            for input in state[self.input]:
-                recurrent_input = gen(input, recurrent_input)
-                results.append(recurrent_input)
-
-            # Now go backwards, skipping the last element (it's already stored in recurrent_input)
-            it = reversed(range(len(results) - 1))
-            for i in it:
-                recurrent_input = gen(results[i], recurrent_input)
-                results.append(recurrent_input)
-
-        new_state = {self.output: results}
-        return new_state
-
-
-class ImageFlowInjector(Injector):
-    def __init__(self, opt, env):
-        # Requires building this custom cuda kernel. Only require it if explicitly needed.
-        from models.networks.layers.resample2d_package.resample2d import Resample2d
-        super(ImageFlowInjector, self).__init__(opt, env)
-        self.resample = Resample2d()
-
-    def forward(self, state):
-        return self.resample(state[self.opt['in']], state[self.opt['flow']])
-
-
-# This is the temporal discriminator loss from TecoGAN.
-#
-# It has a strict contact for 'real' and 'fake' inputs:
-# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
-# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
-#
-# This loss does the following:
-# 1) Picks an image triplet, starting with the first '3' elements in 'real' and 'fake'.
-# 2) Uses the image flow generator (specified with 'image_flow_generator') to create detached flow fields for the first and last images in the above sequence.
-# 3) Warps the first and last images according to the flow field.
-# 4) Composes the three base image and the 2 warped images and middle image into a tensor concatenated at the filter dimension for both real and fake, resulting in a bx18xhxw shape tensor.
-# 5) Feeds the catted real and fake image sets into the discriminator, computes a loss, and backward().
-# 6) Repeat from (1) until all triplets from the real sequence have been exhausted.
-#
-# Note: All steps before 'discriminator_flow_after' do not use triplets. Instead, they use a single image repeated 6 times across the filter dimension.
-class TecoGanDiscriminatorLoss(ConfigurableLoss):
-    def __init__(self, opt, env):
-        super(TecoGanDiscriminatorLoss, self).__init__(opt, env)
-        self.opt = opt
-        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
-        self.discriminator_flow_after = opt['discriminator_flow_after']
-        self.image_flow_generator = opt['image_flow_generator']
-        self.resampler = Resample2d()
-
-    def forward(self, net, state):
-        self.metrics = []
-        flow_gen = self.env['generators'][self.image_flow_generator]
-        real = state[self.opt['real']]
-        fake = state[self.opt['fake']]
-        backwards_count = range(len(real)-2)
-        for i in range(len(real) - 2):
-            real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
-            fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
-
-            d_real = net(real_sext)
-            d_fake = net(fake_sext)
-
-            if self.opt['gan_type'] in ['gan', 'pixgan']:
-                self.metrics.append(("d_fake", torch.mean(d_fake)))
-                self.metrics.append(("d_real", torch.mean(d_real)))
-                l_real = self.criterion(d_real, True)
-                l_fake = self.criterion(d_fake, False)
-                l_total = l_real + l_fake
-            elif self.opt['gan_type'] == 'ragan':
-                d_fake_diff = d_fake - torch.mean(d_real)
-                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-                l_total = (self.criterion(d_real - torch.mean(d_fake), True) +
-                           self.criterion(d_fake_diff, False))
-            else:
-                raise NotImplementedError
-
-            l_total = l_total / backwards_count
-            if self.env['amp']:
-                with amp.scale_loss(l_total, self.env['current_step_optimizers'][0], self.env['amp_loss_id']) as loss:
-                    loss.backward()
-            else:
-                l_total.backward()
\ No newline at end of file
diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py
new file mode 100644
index 00000000..40d66f9b
--- /dev/null
+++ b/codes/models/steps/tecogan_losses.py
@@ -0,0 +1,162 @@
+from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state
+from models.layers.resample2d_package.resample2d import Resample2d
+from models.steps.recurrent import RecurrentController
+from models.steps.injectors import Injector
+import torch
+
+def create_teco_loss(opt, env):
+    type = opt['type']
+    if type == 'teco_generator_gan':
+        return TecoGanGeneratorLoss(opt, env)
+    elif type == 'teco_discriminator_gan':
+        return TecoGanDiscriminatorLoss(opt, env)
+    elif type == "teco_pingpong":
+        return PingPongLoss(opt, env)
+    return None
+
+def create_teco_discriminator_sextuplet(input_list, index, flow_gen, resampler):
+    triplet = input_list[index:index+3]
+    first_flow = flow_gen(triplet[0], triplet[1])
+    last_flow = flow_gen(triplet[2], triplet[1])
+    flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)]
+    return torch.cat(triplet + flow_triplet, dim=1)
+
+
+# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
+# Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated).
+# All computation is done with torch.no_grad().
+class RecurrentImageGeneratorSequenceInjector(Injector):
+    def __init__(self, opt, env):
+        super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
+        self.flow = opt['flow_network']
+        self.resample = Resample2d()
+
+    def forward(self, state):
+        gen = self.env['generators'][self.opt['generator']]
+        flow = self.env['generators'][self.flow]
+        results = []
+        recurrent_input = torch.zeros_like(state[self.input][0])
+
+        # Go forward in the sequence first.
+        first_step = True
+        for input in state[self.input]:
+            if first_step:
+                first_step = False
+            else:
+                flowfield = flow(recurrent_input, input)
+                recurrent_input = self.resample(recurrent_input, flowfield)
+            recurrent_input = gen(input, recurrent_input)
+            results.append(recurrent_input)
+            recurrent_input = self.flow()
+
+        # Now go backwards, skipping the last element (it's already stored in recurrent_input)
+        it = reversed(range(len(results) - 1))
+        for i in it:
+            flowfield = flow(recurrent_input, results[i])
+            recurrent_input = self.resample(recurrent_input, flowfield)
+            recurrent_input = gen(results[i], recurrent_input)
+            results.append(recurrent_input)
+
+        return {self.output: results}
+
+
+# This is the temporal discriminator loss from TecoGAN.
+#
+# It has a strict contract for 'real' and 'fake' inputs:
+# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
+# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
+#
+# This loss does the following:
+# 1) Picks an image triplet, starting with the first '3' elements in 'real' and 'fake'.
+# 2) Uses the image flow generator (specified with 'image_flow_generator') to create detached flow fields for the first and last images in the above sequence.
+# 3) Warps the first and last images according to the flow field.
+# 4) Composes the three base images and the two warped images (plus the middle image) into a tensor concatenated at the filter dimension for both real and fake, resulting in a bx18xhxw shape tensor.
+# 5) Feeds the catted real and fake image sets into the discriminator, computes a loss, and backward().
+# 6) Repeat from (1) until all triplets from the real sequence have been exhausted.
+class TecoGanDiscriminatorLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(TecoGanDiscriminatorLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
+        self.noise = None if 'noise' not in opt.keys() else opt['noise']
+        self.image_flow_generator = opt['image_flow_generator']
+        self.resampler = Resample2d()
+
+    def forward(self, net, state):
+        self.metrics = []
+        flow_gen = self.env['generators'][self.image_flow_generator]
+        real = state[self.opt['real']]
+        fake = state[self.opt['fake']]
+        l_total = 0
+        for i in range(len(real) - 2):
+            real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
+            fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
+
+            d_real = net(real_sext)
+            d_fake = net(fake_sext)
+
+            if self.opt['gan_type'] in ['gan', 'pixgan']:
+                self.metrics.append(("d_fake", torch.mean(d_fake)))
+                self.metrics.append(("d_real", torch.mean(d_real)))
+                l_real = self.criterion(d_real, True)
+                l_fake = self.criterion(d_fake, False)
+                l_total += l_real + l_fake
+            elif self.opt['gan_type'] == 'ragan':
+                d_fake_diff = d_fake - torch.mean(d_real)
+                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
+                l_total += (self.criterion(d_real - torch.mean(d_fake), True) +
+                            self.criterion(d_fake_diff, False))
+            else:
+                raise NotImplementedError
+        return l_total
+
+
+class TecoGanGeneratorLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(TecoGanGeneratorLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
+        # TecoGAN parameters
+        self.image_flow_generator = opt['image_flow_generator']
+        self.resampler = Resample2d()
+
+    def forward(self, _, state):
+        flow_gen = self.env['generators'][self.image_flow_generator]
+        real = state[self.opt['real']]
+        fake = state[self.opt['fake']]
+        l_total = 0
+        for i in range(len(real) - 2):
+            real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
+            fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
+            d_fake = net(fake_sext)
+
+            if self.opt['gan_type'] in ['gan', 'pixgan']:
+                self.metrics.append(("d_fake", torch.mean(d_fake)))
+                l_fake = self.criterion(d_fake, True)
+                l_total += l_fake
+            elif self.opt['gan_type'] == 'ragan':
+                d_real = net(real_sext)
+                d_fake_diff = d_fake - torch.mean(d_real)
+                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
+                l_total += (self.criterion(d_real - torch.mean(d_fake), False) +
+                            self.criterion(d_fake_diff, True))
+            else:
+                raise NotImplementedError
+        return l_total
+
+
+# This loss doesn't have a real entry - only fakes are used.
+class PingPongLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(PingPongLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
+
+    def forward(self, _, state):
+        fake = state[self.opt['fake']]
+        l_total = 0
+        for i in range((len(fake) - 1) // 2):
+            early = fake[i]
+            late = fake[-(i+1)]
+            l_total += self.criterion(early, late)
+        return l_total
\ No newline at end of file
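
For reference, the sextuplet construction used by both GAN losses above concatenates a real or fake triplet with its flow-warped counterpart along the channel dimension, so b RGB frames become a bx18xhxw discriminator input. Below is a minimal standalone sketch of that shape logic; the flow generator and resampler here are hypothetical stand-ins (a zero flow field and an identity warp), not the networks wired up in the patch.

import torch

def toy_flow_gen(img_a, img_b):
    # Hypothetical stand-in for the image flow generator: a zero flow field with
    # 2 channels (dx, dy) per pixel, the shape a Resample2d-style warp expects.
    b, _, h, w = img_a.shape
    return torch.zeros(b, 2, h, w)

def toy_resampler(img, flow):
    # Hypothetical stand-in for Resample2d: a zero flow leaves the image unchanged.
    return img

def make_sextuplet(frames, index, flow_gen, resampler):
    # Mirrors create_teco_discriminator_sextuplet: three consecutive frames plus the
    # first and last frames warped toward the middle one, concatenated on dim=1.
    triplet = frames[index:index + 3]
    first_flow = flow_gen(triplet[0], triplet[1])
    last_flow = flow_gen(triplet[2], triplet[1])
    flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)]
    return torch.cat(triplet + flow_triplet, dim=1)

frames = [torch.randn(4, 3, 32, 32) for _ in range(5)]   # 5 RGB frames, batch of 4
sext = make_sextuplet(frames, 0, toy_flow_gen, toy_resampler)
print(sext.shape)  # torch.Size([4, 18, 32, 32]) -- 6 images x 3 channels each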
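
The ping-pong loss depends on the injector's output ordering: for n input frames the forward pass produces indices 0..n-1 and the backward pass appends the remaining n-1 outputs, so every frame except the last is generated twice and the matching pairs are (i, len-1-i). A small sketch of that pairing follows; it uses an L1 comparison purely as an illustrative assumption, whereas the PingPongLoss class above plugs in its configured criterion.

import torch
import torch.nn.functional as F

def ping_pong_pairs(seq_len):
    # For a forward+backward sequence of length 2*n-1, frame j is generated twice:
    # once at index j (forward pass) and once at index seq_len-1-j (backward pass).
    return [(i, seq_len - 1 - i) for i in range((seq_len - 1) // 2)]

# 3 input frames -> 2*3-1 = 5 generated outputs; random stand-ins for generator output.
fake = [torch.randn(1, 3, 16, 16) for _ in range(5)]

print(ping_pong_pairs(len(fake)))  # [(0, 4), (1, 3)]
l_total = sum(F.l1_loss(fake[i], fake[j]) for i, j in ping_pong_pairs(len(fake)))
print(float(l_total))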