From 17c569ea6236cf411ad97592cb2c6cb31b585448 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 20 Sep 2020 16:24:23 -0600
Subject: [PATCH] Add geometric loss

---
 codes/models/steps/injectors.py               | 15 ++----
 codes/models/steps/losses.py                  | 53 +++++++++++++++++++
 codes/models/steps/recursive_gen_injectors.py | 35 ++++++++++++
 3 files changed, 91 insertions(+), 12 deletions(-)
 create mode 100644 codes/models/steps/recursive_gen_injectors.py

diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index 7ba02420..a136e139 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -1,6 +1,8 @@
 import torch.nn
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from data.weight_scheduler import get_scheduler_for_opt
+from torch.utils.checkpoint import checkpoint
+#from models.steps.recursive_gen_injectors import ImageFlowInjector
 
 # Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
 def create_injector(opt_inject, env):
@@ -136,7 +138,7 @@ class GreyInjector(Injector):
         mean = mean.repeat(1, 3, 1, 1)
         return {self.opt['out']: mean}
 
-
+import torchvision.utils as utils
 class InterpolateInjector(Injector):
     def __init__(self, opt, env):
         super(InterpolateInjector, self).__init__(opt, env)
@@ -145,14 +147,3 @@ class InterpolateInjector(Injector):
         scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
                                                  mode=self.opt['mode'])
         return {self.opt['out']: scaled}
-
-
-class ImageFlowInjector(Injector):
-    def __init__(self, opt, env):
-        # Requires building this custom cuda kernel. Only require it if explicitly needed.
-        from models.networks.layers.resample2d_package.resample2d import Resample2d
-        super(ImageFlowInjector, self).__init__(opt, env)
-        self.resample = Resample2d()
-
-    def forward(self, state):
-        return self.resample(state[self.opt['in']], state[self.opt['flow']])
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index 73fe38b3..d367d318 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -16,6 +16,8 @@ def create_generator_loss(opt_loss, env):
         return GeneratorGanLoss(opt_loss, env)
     elif type == 'discriminator_gan':
         return DiscriminatorGanLoss(opt_loss, env)
+    elif type == 'geometric':
+        return GeometricSimilarityGeneratorLoss(opt_loss, env)
     else:
         raise NotImplementedError
 
@@ -123,6 +125,7 @@ class GeneratorGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError
 
+import torchvision
 
 class DiscriminatorGanLoss(ConfigurableLoss):
     def __init__(self, opt, env):
@@ -165,3 +168,53 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError
 
+import random
+import functools
+
+# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
+# input that has been altered randomly by rotation or flip.
+# The "real" parameter to this loss is the actual output of the generator (from an injection point).
+# The "fake" parameter is the LR input that produced the "real" parameter when fed through the generator.
+class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(GeometricSimilarityGeneratorLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.generator = opt['generator']
+        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
+        self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
+        self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
+        self.detach_fake = opt['detach_fake'] if 'detach_fake' in opt.keys() else False
+
+    # Returns a random alteration and its counterpart (that undoes the alteration)
+    def random_alteration(self):
+        return random.choice([(functools.partial(torch.flip, dims=(2,)), functools.partial(torch.flip, dims=(2,))),
+                              (functools.partial(torch.flip, dims=(3,)), functools.partial(torch.flip, dims=(3,))),
+                              (functools.partial(torch.rot90, k=1, dims=[2,3]), functools.partial(torch.rot90, k=3, dims=[2,3])),
+                              (functools.partial(torch.rot90, k=2, dims=[2,3]), functools.partial(torch.rot90, k=2, dims=[2,3])),
+                              (functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))])
+
+    def forward(self, net, state):
+        self.metrics = []
+        net = self.env['generators'][self.generator]  # Get the network from an explicit parameter.
+        # The parameter is not reliable for generator losses since often they are combined with many networks.
+        fake = extract_params_from_state(self.opt['fake'], state)
+        alteration, undo_fn = self.random_alteration()
+        altered = []
+        for i, t in enumerate(fake):
+            if i == self.gen_input_for_alteration:
+                altered.append(alteration(t))
+            else:
+                altered.append(t)
+        if self.detach_fake:
+            with torch.no_grad():
+                upsampled_altered = net(*altered)
+        else:
+            upsampled_altered = net(*altered)
+
+        if self.gen_output_to_use is not None:
+            upsampled_altered = upsampled_altered[self.gen_output_to_use]
+
+        # Undo alteration on HR image
+        upsampled_altered = undo_fn(upsampled_altered)
+
+        return self.criterion(state[self.opt['real']], upsampled_altered)
\ No newline at end of file
diff --git a/codes/models/steps/recursive_gen_injectors.py b/codes/models/steps/recursive_gen_injectors.py
new file mode 100644
index 00000000..6b97ff3b
--- /dev/null
+++ b/codes/models/steps/recursive_gen_injectors.py
@@ -0,0 +1,35 @@
+import models.steps.injectors as injectors
+import torch
+from torch.utils.checkpoint import checkpoint
+
+
+# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out].
+# All results are checkpointed for memory savings. Recurrent inputs are also detached before being fed back into
+# the generator.
+class RecurrentImageGeneratorSequenceInjector(injectors.Injector):
+    def __init__(self, opt, env):
+        super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
+
+    def forward(self, state):
+        gen = self.env['generators'][self.opt['generator']]
+        new_state = {}
+        results = []
+        recurrent_input = torch.zeros_like(state[self.input][0])
+        for input in state[self.input]:
+            result = checkpoint(gen, input, recurrent_input)
+            results.append(result)
+            recurrent_input = result.detach()
+
+        new_state = {self.output: results}
+        return new_state
+
+
+class ImageFlowInjector(injectors.Injector):
+    def __init__(self, opt, env):
+        # Requires building this custom cuda kernel. Only require it if explicitly needed.
+        from models.networks.layers.resample2d_package.resample2d import Resample2d
+        super(ImageFlowInjector, self).__init__(opt, env)
+        self.resample = Resample2d()
+
+    def forward(self, state):
+        return self.resample(state[self.opt['in']], state[self.opt['flow']])
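
The standalone sketch below illustrates what the new GeometricSimilarityGeneratorLoss computes: run the generator on a flipped/rotated copy of its input, undo the transform on the result, and penalize disagreement with the output produced from the unaltered input. It is not part of the commit; the toy conv "generator", tensor shapes, and L1 criterion are placeholders, whereas in the patch the generator, criterion, and the 'real'/'fake' tensors all come from the training step's env/state/opt dictionaries.

import functools
import random

import torch
import torch.nn as nn

# A trivial stand-in "generator": one conv that preserves spatial size.
toy_gen = nn.Conv2d(3, 3, kernel_size=3, padding=1)

# Alteration / inverse pairs, mirroring the patch: flips invert themselves,
# a k-step rotation is undone by a (4-k)-step rotation.
alterations = [
    (functools.partial(torch.flip, dims=(2,)), functools.partial(torch.flip, dims=(2,))),
    (functools.partial(torch.flip, dims=(3,)), functools.partial(torch.flip, dims=(3,))),
    (functools.partial(torch.rot90, k=1, dims=[2, 3]), functools.partial(torch.rot90, k=3, dims=[2, 3])),
]

lr = torch.randn(1, 3, 16, 16)           # stands in for the "fake" generator input
real = toy_gen(lr)                       # the "real" output from the unaltered input (keeps its graph)
alteration, undo = random.choice(alterations)

with torch.no_grad():                    # mirrors detach_fake=True in the patch
    altered_out = toy_gen(alteration(lr))

loss = nn.L1Loss()(real, undo(altered_out))
loss.backward()                          # gradients flow only through the unaltered pass here
print(loss.item())

With detach_fake=True (the no_grad block above), the altered pass acts as a fixed self-supervision target and gradients reach the generator only through the unaltered output.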
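For wiring the loss up, a hypothetical options block restricted to the keys the patch actually reads is sketched below. The state key names ('gen_img', 'lq') and the 'l1' criterion name are illustrative assumptions, not values taken from this diff.

# Hypothetical options block; only the key names come from GeometricSimilarityGeneratorLoss.
geometric_loss_opt = {
    'type': 'geometric',            # dispatched by the new branch in create_generator_loss()
    'generator': 'generator',       # name looked up in env['generators']
    'criterion': 'l1',              # passed to get_basic_criterion_for_name() (assumed name)
    'real': 'gen_img',              # state key holding the generator output to compare against (made up)
    'fake': 'lq',                   # state key(s) for the generator inputs (made up)
    'input_alteration_index': 0,    # which input tensor gets flipped/rotated
    'detach_fake': True,            # run the altered pass under torch.no_grad()
    # 'generator_output_index': 0,  # only needed when the generator returns a tuple
}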
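Similarly, a minimal sketch of the recurrence implemented by RecurrentImageGeneratorSequenceInjector: a zero tensor seeds the recurrent input, each frame is pushed through the generator under torch.utils.checkpoint, and the result is detached before being fed back. The two-input conv module and the frame list are stand-ins; the real injector fetches the generator from env['generators'] and the frame sequence from state[opt['in']].

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyRecurrentGen(nn.Module):
    # Takes the current frame plus the previous output and produces the next output.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 3, kernel_size=3, padding=1)

    def forward(self, frame, recurrent):
        return self.conv(torch.cat([frame, recurrent], dim=1))


gen = ToyRecurrentGen()
frames = [torch.randn(1, 3, 16, 16) for _ in range(4)]   # stands in for state[opt['in']]

results = []
recurrent_input = torch.zeros_like(frames[0])
for frame in frames:
    # Checkpointing trades compute for memory: the forward pass is re-run during backward.
    result = checkpoint(gen, frame, recurrent_input)
    results.append(result)
    # Detaching stops gradients from flowing backwards through the whole sequence.
    recurrent_input = result.detach()
print(len(results), results[0].shape)

One thing worth verifying in a real run: checkpoint decides whether to build a backward graph from its tensor inputs, so if neither the frame nor the (detached) recurrent input requires grad it may warn and skip gradient tracking for that segment.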