Add ImagePatchInjector and TranslationalLoss
This commit is contained in:
parent
d8621e611a
commit
31641d7f63
|
@ -2,6 +2,7 @@ import torch.nn
|
|||
from models.archs.SPSR_arch import ImageGradientNoPadding
|
||||
from data.weight_scheduler import get_scheduler_for_opt
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import torchvision.utils as utils
|
||||
#from models.steps.recursive_gen_injectors import ImageFlowInjector
|
||||
|
||||
# Injectors are a way to sythesize data within a step that can then be used (and reused) by loss functions.
|
||||
|
@ -23,6 +24,8 @@ def create_injector(opt_inject, env):
|
|||
return InterpolateInjector(opt_inject, env)
|
||||
elif type == 'imageflow':
|
||||
return ImageFlowInjector(opt_inject, env)
|
||||
elif type == 'image_patch':
|
||||
return ImagePatchInjector(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -138,7 +141,7 @@ class GreyInjector(Injector):
|
|||
mean = mean.repeat(1, 3, 1, 1)
|
||||
return {self.opt['out']: mean}
|
||||
|
||||
import torchvision.utils as utils
|
||||
|
||||
class InterpolateInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(InterpolateInjector, self).__init__(opt, env)
|
||||
|
@ -147,3 +150,35 @@ class InterpolateInjector(Injector):
|
|||
scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
|
||||
mode=self.opt['mode'])
|
||||
return {self.opt['out']: scaled}
|
||||
|
||||
|
||||
# Extracts four patches from the input image, each a square of 'patch_size'. The input images are taken from each
|
||||
# of the four corners of the image. The intent of this loss is that each patch shares some part of the input, which
|
||||
# can then be used in the translation invariance loss.
|
||||
#
|
||||
# This injector is unique in that it does not only produce the specified output label into state. Instead it produces five
|
||||
# outputs for the specified label, one for each corner of the input as well as the specified output, which is the top left
|
||||
# corner. See the code below to find out how this works.
|
||||
#
|
||||
# Another note: this injector operates differently in eval mode (e.g. when env['training']=False) - in this case, it
|
||||
# simply sets all the output state variables to the input. This is so that you can feed the output of this injector
|
||||
# directly into your generator in training without affecting test performance.
|
||||
class ImagePatchInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ImagePatchInjector, self).__init__(opt, env)
|
||||
self.patch_size = opt['patch_size']
|
||||
|
||||
def forward(self, state):
|
||||
im = state[self.opt['in']]
|
||||
if self.env['training']:
|
||||
return { self.opt['out']: im[:, :self.patch_size, :self.patch_size],
|
||||
'%s_top_left' % (self.opt['out'],): im[:, :self.patch_size, :self.patch_size],
|
||||
'%s_top_right' % (self.opt['out'],): im[:, :self.patch_size, -self.patch_size:],
|
||||
'%s_bottom_left' % (self.opt['out'],): im[:, -self.patch_size:, :self.patch_size],
|
||||
'%s_bottom_right' % (self.opt['out'],): im[:, -self.patch_size:, -self.patch_size:] }
|
||||
else:
|
||||
return { self.opt['out']: im,
|
||||
'%s_top_left' % (self.opt['out'],): im,
|
||||
'%s_top_right' % (self.opt['out'],): im,
|
||||
'%s_bottom_left' % (self.opt['out'],): im,
|
||||
'%s_bottom_right' % (self.opt['out'],): im }
|
||||
|
|
|
@ -2,6 +2,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
from models.networks import define_F
|
||||
from models.loss import GANLoss
|
||||
import random
|
||||
import functools
|
||||
|
||||
|
||||
def create_generator_loss(opt_loss, env):
|
||||
|
@ -18,6 +20,8 @@ def create_generator_loss(opt_loss, env):
|
|||
return DiscriminatorGanLoss(opt_loss, env)
|
||||
elif type == 'geometric':
|
||||
return GeometricSimilarityGeneratorLoss(opt_loss, env)
|
||||
elif type == 'translational':
|
||||
return TranslationInvarianceLoss(opt_loss, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -190,8 +194,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
import random
|
||||
import functools
|
||||
|
||||
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
||||
# input that has been altered randomly by rotation or flip.
|
||||
|
@ -239,4 +241,45 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
|
|||
# Undo alteration on HR image
|
||||
upsampled_altered = undo_fn(upsampled_altered)
|
||||
|
||||
return self.criterion(state[self.opt['real']], upsampled_altered)
|
||||
return self.criterion(state[self.opt['real']], upsampled_altered)
|
||||
|
||||
|
||||
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
||||
# input that has been translated in a random direction.
|
||||
# The "real" parameter to this loss is the actual output of the generator on the top left image patch.
|
||||
# The "fake" parameter is the output base fed into a ImagePatchInjector.
|
||||
class TranslationInvarianceLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super(TranslationInvarianceLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.generator = opt['generator']
|
||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||
self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
|
||||
self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
|
||||
self.patch_size = opt['patch_size']
|
||||
self.overlap = opt['overlap'] # For maximum overlap, can be calculated as 2*patch_size-image_size
|
||||
assert(self.patch_size > self.overlap)
|
||||
|
||||
def forward(self, net, state):
|
||||
self.metrics = []
|
||||
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
|
||||
# The <net> parameter is not reliable for generator losses since often they are combined with many networks.
|
||||
|
||||
border_sz = self.patch_size - self.overlap
|
||||
translation = random.choice([("top_right", border_sz, border_sz+self.overlap, 0, self.overlap),
|
||||
("bottom_left", 0, self.overlap, border_sz, border_sz+self.overlap),
|
||||
("bottom_right", 0, self.overlap, 0, self.overlap)])
|
||||
trans_name, hl, hh, wl, wh = translation
|
||||
# Change the "fake" input name that we are translating to one that specifies the random translation.
|
||||
self.opt['fake'][self.gen_input_for_alteration] = "%s_%s" % (self.opt['fake'], trans_name)
|
||||
input = extract_params_from_state(self.opt['fake'], state)
|
||||
with torch.no_grad():
|
||||
trans_output = net(*input)
|
||||
fake_shared_output = trans_output[:, hl:hh, wl:wh][self.gen_output_to_use]
|
||||
|
||||
# The "real" input is assumed to always come from the top left tile.
|
||||
gen_output = state[self.opt['real']]
|
||||
real_shared_output = gen_output[:, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap][self.gen_output_to_use]
|
||||
|
||||
return self.criterion(fake_shared_output, real_shared_output)
|
||||
|
||||
|
|
|
@ -112,6 +112,7 @@ class ConfigurableStep(Module):
|
|||
# Some losses compute backward() internally. Accomodate this by stashing the amp_loss_id in env.
|
||||
self.env['amp_loss_id'] = amp_loss_id
|
||||
self.env['current_step_optimizers'] = self.optimizers
|
||||
self.env['training'] = train
|
||||
|
||||
# Inject in any extra dependencies.
|
||||
for inj in self.injectors:
|
||||
|
|
Loading…
Reference in New Issue
Block a user