diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index a136e139..74ad5a72 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -2,6 +2,7 @@ import torch.nn from models.archs.SPSR_arch import ImageGradientNoPadding from data.weight_scheduler import get_scheduler_for_opt from torch.utils.checkpoint import checkpoint +import torchvision.utils as utils #from models.steps.recursive_gen_injectors import ImageFlowInjector # Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions. @@ -23,6 +24,8 @@ def create_injector(opt_inject, env): return InterpolateInjector(opt_inject, env) elif type == 'imageflow': return ImageFlowInjector(opt_inject, env) + elif type == 'image_patch': + return ImagePatchInjector(opt_inject, env) else: raise NotImplementedError @@ -138,7 +141,7 @@ class GreyInjector(Injector): mean = mean.repeat(1, 3, 1, 1) return {self.opt['out']: mean} -import torchvision.utils as utils + class InterpolateInjector(Injector): def __init__(self, opt, env): super(InterpolateInjector, self).__init__(opt, env) @@ -147,3 +150,35 @@ class InterpolateInjector(Injector): scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'], mode=self.opt['mode']) return {self.opt['out']: scaled} + + +# Extracts four patches from the input image, each a square of 'patch_size'. The input images are taken from each +# of the four corners of the image. The intent of this loss is that each patch shares some part of the input, which +# can then be used in the translation invariance loss. +# +# This injector is unique in that it does not only produce the specified output label into state. Instead it produces five +# outputs for the specified label, one for each corner of the input as well as the specified output, which is the top left +# corner. See the code below to find out how this works. 
class ImagePatchInjector(Injector):
    """Extract four square corner patches of side `patch_size` from an image.

    Besides the configured output label (opt['out']), this injector writes
    four additional state entries named '<out>_top_left', '<out>_top_right',
    '<out>_bottom_left' and '<out>_bottom_right'. The plain '<out>' entry is
    the same patch as '<out>_top_left'. When the patches are larger than half
    the image they overlap, which is what a translation invariance loss can
    then compare.

    In eval mode (env['training'] is falsy) no cropping is done: all five
    entries are set to the unmodified input, so the injector output can be fed
    straight into a generator without affecting test behavior.

    Options read from `opt`:
        in:         state key of the input image tensor.
        out:        base name for the produced state entries.
        patch_size: side length of each square patch, in pixels.
    """

    def __init__(self, opt, env):
        super(ImagePatchInjector, self).__init__(opt, env)
        self.patch_size = opt['patch_size']

    def forward(self, state):
        """Return a dict with the five patch entries described on the class.

        The crop is applied to the LAST TWO dimensions of the input, which are
        taken to be (height, width) - e.g. a (batch, channels, H, W) tensor.
        Slicing with an Ellipsis keeps every leading dimension intact; the
        previous `im[:, :ps, :ps]` form cropped the channel and height axes of
        a 4D batch instead of height and width.
        """
        im = state[self.opt['in']]
        out = self.opt['out']

        if self.env['training']:
            ps = self.patch_size
            head, tail = slice(None, ps), slice(-ps, None)
            top_left = im[..., head, head]
            top_right = im[..., head, tail]
            bottom_left = im[..., tail, head]
            bottom_right = im[..., tail, tail]
        else:
            # Eval: pass the full image through under every label.
            top_left = top_right = bottom_left = bottom_right = im

        return {
            out: top_left,
            '%s_top_left' % (out,): top_left,
            '%s_top_right' % (out,): top_right,
            '%s_bottom_left' % (out,): bottom_left,
            '%s_bottom_right' % (out,): bottom_right,
        }
class TranslationInvarianceLoss(ConfigurableLoss):
    """Penalize a generator for not being translation invariant.

    The generator is run (without gradients) on a randomly chosen translated
    patch of the input - the top right, bottom left or bottom right corner as
    produced by an ImagePatchInjector - and the region that patch shares with
    the top left patch is compared against the same region of the generator's
    existing output for the top left patch.

    Options read from `opt`:
        generator:               key into env['generators'] of the network to run.
        criterion:               name handed to get_basic_criterion_for_name.
        real:                    state key holding the generator output for the
                                 top left patch.
        fake:                    list of state keys used as generator inputs. The
                                 entry at `input_alteration_index` is the base
                                 name that was given to the ImagePatchInjector.
        input_alteration_index:  (optional, default 0) which entry of `fake` to
                                 swap for the translated patch.
        generator_output_index:  (optional) which element to use when an output
                                 is a list/tuple.
        patch_size:              side length of the patches.
        overlap:                 side length of the shared square region. For
                                 maximum overlap this is 2*patch_size-image_size.
    """

    def __init__(self, opt, env):
        super(TranslationInvarianceLoss, self).__init__(opt, env)
        self.opt = opt
        self.generator = opt['generator']
        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
        self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
        self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
        self.patch_size = opt['patch_size']
        self.overlap = opt['overlap']  # For maximum overlap, can be calculated as 2*patch_size-image_size
        assert(self.patch_size > self.overlap)

    def _select_output(self, output):
        # Pick a single tensor out of a (possibly multi-valued) generator output.
        if isinstance(output, (list, tuple)):
            index = 0 if self.gen_output_to_use is None else self.gen_output_to_use
            return output[index]
        return output

    def forward(self, net, state):
        """Compute the criterion between the shared regions of two patches.

        `net` is ignored: the network is looked up explicitly in
        env['generators'], because the parameter is not reliable for generator
        losses, which are often combined with many networks.
        """
        self.metrics = []
        net = self.env['generators'][self.generator]

        border_sz = self.patch_size - self.overlap
        # Each entry: (patch name, row_lo, row_hi, col_lo, col_hi) of the shared
        # region expressed in the coordinates of the translated patch.
        trans_name, hl, hh, wl, wh = random.choice([
            ("top_right", border_sz, border_sz + self.overlap, 0, self.overlap),
            ("bottom_left", 0, self.overlap, border_sz, border_sz + self.overlap),
            ("bottom_right", 0, self.overlap, 0, self.overlap),
        ])

        # Build the translated input names on a COPY. Writing back into
        # self.opt['fake'] would append another suffix on every call, and the
        # name must be derived from the single entry being replaced, not from
        # the whole list.
        fake_keys = list(self.opt['fake'])
        altered = self.gen_input_for_alteration
        fake_keys[altered] = "%s_%s" % (fake_keys[altered], trans_name)
        gen_inputs = extract_params_from_state(fake_keys, state)

        with torch.no_grad():
            trans_output = net(*gen_inputs)
        # Select the output first, then crop its last two (height, width) dims.
        fake_shared_output = self._select_output(trans_output)[..., hl:hh, wl:wh]

        # The "real" input is assumed to always come from the top left tile, so
        # the shared region sits in its bottom right corner.
        shared = slice(border_sz, border_sz + self.overlap)
        real_shared_output = self._select_output(state[self.opt['real']])[..., shared, shared]

        return self.criterion(fake_shared_output, real_shared_output)