Add ImagePatchInjector and TranslationalLoss

This commit is contained in:
James Betker 2020-09-26 21:25:32 -06:00
parent d8621e611a
commit 31641d7f63
3 changed files with 83 additions and 4 deletions

View File

@ -2,6 +2,7 @@ import torch.nn
from models.archs.SPSR_arch import ImageGradientNoPadding
from data.weight_scheduler import get_scheduler_for_opt
from torch.utils.checkpoint import checkpoint
import torchvision.utils as utils
#from models.steps.recursive_gen_injectors import ImageFlowInjector
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
@ -23,6 +24,8 @@ def create_injector(opt_inject, env):
return InterpolateInjector(opt_inject, env)
elif type == 'imageflow':
return ImageFlowInjector(opt_inject, env)
elif type == 'image_patch':
return ImagePatchInjector(opt_inject, env)
else:
raise NotImplementedError
@ -138,7 +141,7 @@ class GreyInjector(Injector):
mean = mean.repeat(1, 3, 1, 1)
return {self.opt['out']: mean}
import torchvision.utils as utils
class InterpolateInjector(Injector):
def __init__(self, opt, env):
super(InterpolateInjector, self).__init__(opt, env)
@ -147,3 +150,35 @@ class InterpolateInjector(Injector):
scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
mode=self.opt['mode'])
return {self.opt['out']: scaled}
# Extracts four patches from the input image, each a square of 'patch_size'. The input images are taken from each
# of the four corners of the image. The intent of this loss is that each patch shares some part of the input, which
# can then be used in the translation invariance loss.
#
# This injector is unique in that it does not only produce the specified output label into state. Instead it produces five
# outputs for the specified label, one for each corner of the input as well as the specified output, which is the top left
# corner. See the code below to find out how this works.
#
# Another note: this injector operates differently in eval mode (e.g. when env['training']=False) - in this case, it
# simply sets all the output state variables to the input. This is so that you can feed the output of this injector
# directly into your generator in training without affecting test performance.
class ImagePatchInjector(Injector):
def __init__(self, opt, env):
super(ImagePatchInjector, self).__init__(opt, env)
self.patch_size = opt['patch_size']
def forward(self, state):
im = state[self.opt['in']]
if self.env['training']:
return { self.opt['out']: im[:, :self.patch_size, :self.patch_size],
'%s_top_left' % (self.opt['out'],): im[:, :self.patch_size, :self.patch_size],
'%s_top_right' % (self.opt['out'],): im[:, :self.patch_size, -self.patch_size:],
'%s_bottom_left' % (self.opt['out'],): im[:, -self.patch_size:, :self.patch_size],
'%s_bottom_right' % (self.opt['out'],): im[:, -self.patch_size:, -self.patch_size:] }
else:
return { self.opt['out']: im,
'%s_top_left' % (self.opt['out'],): im,
'%s_top_right' % (self.opt['out'],): im,
'%s_bottom_left' % (self.opt['out'],): im,
'%s_bottom_right' % (self.opt['out'],): im }

View File

@ -2,6 +2,8 @@ import torch
import torch.nn as nn
from models.networks import define_F
from models.loss import GANLoss
import random
import functools
def create_generator_loss(opt_loss, env):
@ -18,6 +20,8 @@ def create_generator_loss(opt_loss, env):
return DiscriminatorGanLoss(opt_loss, env)
elif type == 'geometric':
return GeometricSimilarityGeneratorLoss(opt_loss, env)
elif type == 'translational':
return TranslationInvarianceLoss(opt_loss, env)
else:
raise NotImplementedError
@ -190,8 +194,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
else:
raise NotImplementedError
import random
import functools
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
# input that has been altered randomly by rotation or flip.
@ -239,4 +241,45 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
# Undo alteration on HR image
upsampled_altered = undo_fn(upsampled_altered)
return self.criterion(state[self.opt['real']], upsampled_altered)
return self.criterion(state[self.opt['real']], upsampled_altered)
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
# input that has been translated in a random direction.
# The "real" parameter to this loss is the actual output of the generator on the top left image patch.
# The "fake" parameter is the output base fed into a ImagePatchInjector.
class TranslationInvarianceLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(TranslationInvarianceLoss, self).__init__(opt, env)
self.opt = opt
self.generator = opt['generator']
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
self.patch_size = opt['patch_size']
self.overlap = opt['overlap'] # For maximum overlap, can be calculated as 2*patch_size-image_size
assert(self.patch_size > self.overlap)
def forward(self, net, state):
self.metrics = []
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
# The <net> parameter is not reliable for generator losses since often they are combined with many networks.
border_sz = self.patch_size - self.overlap
translation = random.choice([("top_right", border_sz, border_sz+self.overlap, 0, self.overlap),
("bottom_left", 0, self.overlap, border_sz, border_sz+self.overlap),
("bottom_right", 0, self.overlap, 0, self.overlap)])
trans_name, hl, hh, wl, wh = translation
# Change the "fake" input name that we are translating to one that specifies the random translation.
self.opt['fake'][self.gen_input_for_alteration] = "%s_%s" % (self.opt['fake'], trans_name)
input = extract_params_from_state(self.opt['fake'], state)
with torch.no_grad():
trans_output = net(*input)
fake_shared_output = trans_output[:, hl:hh, wl:wh][self.gen_output_to_use]
# The "real" input is assumed to always come from the top left tile.
gen_output = state[self.opt['real']]
real_shared_output = gen_output[:, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap][self.gen_output_to_use]
return self.criterion(fake_shared_output, real_shared_output)

View File

@ -112,6 +112,7 @@ class ConfigurableStep(Module):
# Some losses compute backward() internally. Accomodate this by stashing the amp_loss_id in env.
self.env['amp_loss_id'] = amp_loss_id
self.env['current_step_optimizers'] = self.optimizers
self.env['training'] = train
# Inject in any extra dependencies.
for inj in self.injectors: