forked from mrq/DL-Art-School

Add ImagePatchInjector and TranslationalLoss

parent d8621e611a
commit 31641d7f63
@@ -2,6 +2,7 @@ import torch.nn
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from data.weight_scheduler import get_scheduler_for_opt
 from torch.utils.checkpoint import checkpoint
+import torchvision.utils as utils
 #from models.steps.recursive_gen_injectors import ImageFlowInjector
 
 # Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
@@ -23,6 +24,8 @@ def create_injector(opt_inject, env):
         return InterpolateInjector(opt_inject, env)
     elif type == 'imageflow':
         return ImageFlowInjector(opt_inject, env)
+    elif type == 'image_patch':
+        return ImagePatchInjector(opt_inject, env)
     else:
         raise NotImplementedError
 
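For illustration, a minimal sketch of how the new injector type might be wired up, assuming an env dict as used throughout these files. Only 'type', 'in', 'out', and 'patch_size' are keys the injector below actually reads; the label values are hypothetical:

opt_inject = {
    'type': 'image_patch',   # dispatched to ImagePatchInjector by create_injector()
    'in': 'lq',              # hypothetical state key holding the input image batch
    'out': 'lq_patched',     # base label; also emits lq_patched_top_left, _top_right, ...
    'patch_size': 48,
}
injector = create_injector(opt_inject, env)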
@@ -138,7 +141,7 @@ class GreyInjector(Injector):
         mean = mean.repeat(1, 3, 1, 1)
         return {self.opt['out']: mean}
 
-import torchvision.utils as utils
+
 class InterpolateInjector(Injector):
     def __init__(self, opt, env):
         super(InterpolateInjector, self).__init__(opt, env)
@@ -147,3 +150,35 @@ class InterpolateInjector(Injector):
         scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
                                                  mode=self.opt['mode'])
         return {self.opt['out']: scaled}
+
+
+# Extracts four patches from the input image, each a square of 'patch_size'. The patches are taken from each of the
+# four corners of the image. The intent of this injector is that each patch shares some part of the input with its
+# neighbors, which can then be used in the translation invariance loss.
+#
+# This injector is unique in that it does not only produce the specified output label into state. Instead, it produces
+# five outputs for the specified label: one for each corner of the input, plus the specified output itself, which is
+# the top-left corner. See the code below to find out how this works.
+#
+# Another note: this injector operates differently in eval mode (e.g. when env['training']=False) - in this case, it
+# simply sets all the output state variables to the input. This is so that you can feed the output of this injector
+# directly into your generator in training without affecting test performance.
+class ImagePatchInjector(Injector):
+    def __init__(self, opt, env):
+        super(ImagePatchInjector, self).__init__(opt, env)
+        self.patch_size = opt['patch_size']
+
+    def forward(self, state):
+        im = state[self.opt['in']]  # Assumed (batch, channels, height, width); slice only the spatial dims.
+        if self.env['training']:
+            return {self.opt['out']: im[:, :, :self.patch_size, :self.patch_size],
+                    '%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
+                    '%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
+                    '%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
+                    '%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:]}
+        else:
+            return {self.opt['out']: im,
+                    '%s_top_left' % (self.opt['out'],): im,
+                    '%s_top_right' % (self.opt['out'],): im,
+                    '%s_bottom_left' % (self.opt['out'],): im,
+                    '%s_bottom_right' % (self.opt['out'],): im}
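A quick sanity check of the patch geometry, as a sketch with hypothetical sizes (a 64px image and 48px patches; neither number comes from this commit). Horizontally adjacent corner patches share a strip of width 2*patch_size - image_size:

import torch

image_size, patch_size = 64, 48
im = torch.arange(image_size * image_size, dtype=torch.float32).view(1, 1, image_size, image_size)

top_left = im[:, :, :patch_size, :patch_size]
top_right = im[:, :, :patch_size, -patch_size:]

overlap = 2 * patch_size - image_size  # 32: width of the region shared by the two patches
# The right edge of top_left views the same pixels as the left edge of top_right.
assert torch.equal(top_left[:, :, :, -overlap:], top_right[:, :, :, :overlap])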
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 from models.networks import define_F
 from models.loss import GANLoss
+import random
+import functools
 
 
 def create_generator_loss(opt_loss, env):
@@ -18,6 +20,8 @@ def create_generator_loss(opt_loss, env):
         return DiscriminatorGanLoss(opt_loss, env)
     elif type == 'geometric':
         return GeometricSimilarityGeneratorLoss(opt_loss, env)
+    elif type == 'translational':
+        return TranslationInvarianceLoss(opt_loss, env)
     else:
         raise NotImplementedError
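A sketch of the loss-side configuration that would pair with the image_patch injector above. The keys 'generator', 'criterion', 'real', 'fake', 'patch_size', and 'overlap' are read by TranslationInvarianceLoss below; every value shown is hypothetical, and 'l1' assumes get_basic_criterion_for_name() accepts that name:

opt_loss = {
    'type': 'translational',     # dispatched to TranslationInvarianceLoss
    'generator': 'gen',          # hypothetical key into env['generators']
    'criterion': 'l1',           # assumed criterion name
    'real': 'gen_output',        # generator output on the top-left patch
    'fake': ['lq_patched'],      # base label produced by the ImagePatchInjector
    'patch_size': 48,
    'overlap': 32,
}
loss = create_generator_loss(opt_loss, env)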
@@ -190,8 +194,6 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError
 
-import random
-import functools
 
 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
 # input that has been altered randomly by rotation or flip.
@@ -239,4 +241,51 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
         # Undo alteration on HR image
         upsampled_altered = undo_fn(upsampled_altered)
 
         return self.criterion(state[self.opt['real']], upsampled_altered)
+
+
+# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
+# input that has been translated in a random direction.
+# The "real" parameter to this loss is the actual output of the generator on the top-left image patch.
+# The "fake" parameter is the output base fed into an ImagePatchInjector.
+class TranslationInvarianceLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(TranslationInvarianceLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.generator = opt['generator']
+        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
+        self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
+        self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
+        self.patch_size = opt['patch_size']
+        self.overlap = opt['overlap']  # For maximum overlap, this can be calculated as 2*patch_size-image_size.
+        assert self.patch_size > self.overlap
+
+    def forward(self, net, state):
+        self.metrics = []
+        # Get the network from an explicit parameter; the <net> argument is not reliable for generator losses,
+        # since they are often combined with many networks.
+        net = self.env['generators'][self.generator]
+
+        border_sz = self.patch_size - self.overlap
+        translation = random.choice([("top_right", border_sz, border_sz + self.overlap, 0, self.overlap),
+                                     ("bottom_left", 0, self.overlap, border_sz, border_sz + self.overlap),
+                                     ("bottom_right", 0, self.overlap, 0, self.overlap)])
+        trans_name, hl, hh, wl, wh = translation
+        # Change the "fake" input name being translated to the one specifying the chosen random translation.
+        # Work on a copy so that repeated calls do not compound the suffix.
+        fake = self.opt['fake'].copy()
+        fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
+        input = extract_params_from_state(fake, state)
+        with torch.no_grad():
+            trans_output = net(*input)
+        if self.gen_output_to_use is not None:
+            trans_output = trans_output[self.gen_output_to_use]
+        fake_shared_output = trans_output[:, :, hl:hh, wl:wh]
+
+        # The "real" input is assumed to always come from the top-left tile.
+        gen_output = state[self.opt['real']]
+        if self.gen_output_to_use is not None:
+            gen_output = gen_output[self.gen_output_to_use]
+        real_shared_output = gen_output[:, :, border_sz:border_sz + self.overlap, border_sz:border_sz + self.overlap]
+
+        return self.criterion(fake_shared_output, real_shared_output)
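To make the translation geometry concrete, a worked example with hypothetical numbers (none from this commit). It shows where the square shared by the top-left patch and a translated patch lands in each patch's coordinates, and why the assert requires patch_size > overlap:

image_size, patch_size = 64, 48
overlap = 2 * patch_size - image_size    # 32: the maximum overlap noted in the code above
border_sz = patch_size - overlap         # 16: must be positive, hence assert(patch_size > overlap)

# In the top-left patch, the shared square spans [border_sz, border_sz+overlap) = [16, 48)
# in both height and width - the slice taken from state[self.opt['real']].
# In the top-right patch, the same pixels sit at rows [16, 48) but cols [0, 32),
# matching the ("top_right", 16, 48, 0, 32) tuple that random.choice() can select.
hl, hh, wl, wh = border_sz, border_sz + overlap, 0, overlap
assert (hl, hh, wl, wh) == (16, 48, 0, 32)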
@@ -112,6 +112,7 @@ class ConfigurableStep(Module):
         # Some losses compute backward() internally. Accommodate this by stashing the amp_loss_id in env.
         self.env['amp_loss_id'] = amp_loss_id
         self.env['current_step_optimizers'] = self.optimizers
+        self.env['training'] = train
 
         # Inject in any extra dependencies.
         for inj in self.injectors: