forked from mrq/DL-Art-School

Commit: 17c569ea62
Parent: 17dd99b29b

Add geometric loss
codes/models/steps/injectors.py

```diff
@@ -1,6 +1,8 @@
 import torch.nn
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from data.weight_scheduler import get_scheduler_for_opt
+from torch.utils.checkpoint import checkpoint
+#from models.steps.recursive_gen_injectors import ImageFlowInjector
 
 # Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
 def create_injector(opt_inject, env):
@@ -136,7 +138,7 @@ class GreyInjector(Injector):
         mean = mean.repeat(1, 3, 1, 1)
         return {self.opt['out']: mean}
 
+import torchvision.utils as utils
 class InterpolateInjector(Injector):
     def __init__(self, opt, env):
         super(InterpolateInjector, self).__init__(opt, env)
@@ -145,14 +147,3 @@ class InterpolateInjector(Injector):
         scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
                                                  mode=self.opt['mode'])
         return {self.opt['out']: scaled}
-
-
-class ImageFlowInjector(Injector):
-    def __init__(self, opt, env):
-        # Requires building this custom cuda kernel. Only require it if explicitly needed.
-        from models.networks.layers.resample2d_package.resample2d import Resample2d
-        super(ImageFlowInjector, self).__init__(opt, env)
-        self.resample = Resample2d()
-
-    def forward(self, state):
-        return self.resample(state[self.opt['in']], state[self.opt['flow']])
```
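For orientation, the injector contract visible in these hunks can be sketched standalone: an injector is built from an options dict plus an environment, and its `forward()` reads tensors out of a shared `state` dict and returns new entries keyed by `opt['out']`. The minimal `Injector` base below is a stand-in assumption (not the repository's class); only the shape shown in the diff is taken as given.

```python
import torch
import torch.nn.functional as F

class Injector:
    # Stand-in base: the real class lives in injectors.py and does more.
    def __init__(self, opt, env):
        self.opt, self.env = opt, env

class InterpolateInjector(Injector):
    def forward(self, state):
        # Read the input tensor out of the shared state, emit a new keyed entry.
        scaled = F.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
                               mode=self.opt['mode'])
        return {self.opt['out']: scaled}

state = {'lq': torch.randn(1, 3, 32, 32)}
inj = InterpolateInjector({'in': 'lq', 'out': 'lq_up', 'scale_factor': 2, 'mode': 'nearest'}, env={})
state.update(inj.forward(state))
print(state['lq_up'].shape)  # torch.Size([1, 3, 64, 64])
```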
codes/models/steps/losses.py

```diff
@@ -16,6 +16,8 @@ def create_generator_loss(opt_loss, env):
         return GeneratorGanLoss(opt_loss, env)
     elif type == 'discriminator_gan':
         return DiscriminatorGanLoss(opt_loss, env)
+    elif type == 'geometric':
+        return GeometricSimilarityGeneratorLoss(opt_loss, env)
     else:
         raise NotImplementedError
@@ -123,6 +125,7 @@ class GeneratorGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError
 
+import torchvision
 
 class DiscriminatorGanLoss(ConfigurableLoss):
     def __init__(self, opt, env):
@@ -165,3 +168,53 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError
 
+import random
+import functools
+
+# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
+# input that has been altered randomly by rotation or flip.
+# The "real" parameter to this loss is the actual output of the generator (from an injection point)
+# The "fake" parameter is the LR input that produced the "real" parameter when fed through the generator.
+class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
+    def __init__(self, opt, env):
+        super(GeometricSimilarityGeneratorLoss, self).__init__(opt, env)
+        self.opt = opt
+        self.generator = opt['generator']
+        self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
+        self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
+        self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
+        self.detach_fake = opt['detach_fake'] if 'detach_fake' in opt.keys() else False
+
+    # Returns a random alteration and its counterpart (that undoes the alteration)
+    def random_alteration(self):
+        return random.choice([(functools.partial(torch.flip, dims=(2,)), functools.partial(torch.flip, dims=(2,))),
+                              (functools.partial(torch.flip, dims=(3,)), functools.partial(torch.flip, dims=(3,))),
+                              (functools.partial(torch.rot90, k=1, dims=[2,3]), functools.partial(torch.rot90, k=3, dims=[2,3])),
+                              (functools.partial(torch.rot90, k=2, dims=[2,3]), functools.partial(torch.rot90, k=2, dims=[2,3])),
+                              (functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))])
+
+    def forward(self, net, state):
+        self.metrics = []
+        net = self.env['generators'][self.generator]  # Get the network from an explicit parameter.
+        # The <net> parameter is not reliable for generator losses since often they are combined with many networks.
+        fake = extract_params_from_state(self.opt['fake'], state)
+        alteration, undo_fn = self.random_alteration()
+        altered = []
+        for i, t in enumerate(fake):
+            if i == self.gen_input_for_alteration:
+                altered.append(alteration(t))
+            else:
+                altered.append(t)
+        if self.detach_fake:
+            with torch.no_grad():
+                upsampled_altered = net(*altered)
+        else:
+            upsampled_altered = net(*altered)
+
+        if self.gen_output_to_use:
+            upsampled_altered = upsampled_altered[self.gen_output_to_use]
+
+        # Undo alteration on HR image
+        upsampled_altered = undo_fn(upsampled_altered)
+
+        return self.criterion(state[self.opt['real']], upsampled_altered)
```
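In effect this is an equivariance penalty: for a geometric transform T with inverse T⁻¹, the generator G should satisfy T⁻¹(G(T(x))) ≈ G(x), so the loss compares the stored output G(x) (the `real` key) against T⁻¹(G(T(x))) recomputed from the LR input (the `fake` key). A quick self-contained check, not part of the commit, that each (alteration, undo) pair above really composes to the identity on an NCHW tensor:

```python
import functools
import torch

# Same pairs as random_alteration(): flips are self-inverse, rot90(k) is undone by rot90(4-k).
pairs = [(functools.partial(torch.flip, dims=(2,)), functools.partial(torch.flip, dims=(2,))),
         (functools.partial(torch.flip, dims=(3,)), functools.partial(torch.flip, dims=(3,))),
         (functools.partial(torch.rot90, k=1, dims=[2, 3]), functools.partial(torch.rot90, k=3, dims=[2, 3])),
         (functools.partial(torch.rot90, k=2, dims=[2, 3]), functools.partial(torch.rot90, k=2, dims=[2, 3])),
         (functools.partial(torch.rot90, k=3, dims=[2, 3]), functools.partial(torch.rot90, k=1, dims=[2, 3]))]

x = torch.randn(1, 3, 8, 8)
for do, undo in pairs:
    # Flips and rotations only permute pixels, so equality is exact.
    assert torch.equal(undo(do(x)), x)
```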
codes/models/steps/recursive_gen_injectors.py (new file, 33 lines)
```diff
@@ -0,0 +1,33 @@
+import models.steps.injectors as injectors
+
+
+# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
+# All results are checkpointed for memory savings. Recurrent inputs are also detached before being fed back into
+# the generator.
+class RecurrentImageGeneratorSequenceInjector(injectors.Injector):
+    def __init__(self, opt, env):
+        super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
+
+    def forward(self, state):
+        gen = self.env['generators'][self.opt['generator']]
+        new_state = {}
+        results = []
+        recurrent_input = torch.zeros_like(state[self.input][0])
+        for input in state[self.input]:
+            result = checkpoint(gen, input, recurrent_input)
+            results.append(result)
+            recurrent_input = result.detach()
+
+        new_state = {self.output: results}
+        return new_state
```
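Two details in this loop are worth unpacking. `checkpoint(gen, ...)` trades compute for memory by recomputing the generator's activations during the backward pass, and `result.detach()` truncates backpropagation between time steps, so each step's gradient stops at its recurrent input rather than unrolling through the whole sequence. (Note that as committed, the file uses `torch` and `checkpoint` without importing them itself.) A toy illustration of the detach, using a stand-in convolutional "generator" that is an assumption for demonstration, not the commit's network:

```python
import torch

# Stand-in generator: consumes the current frame and the recurrent input stacked on channels.
gen = torch.nn.Conv2d(6, 3, 3, padding=1)

frames = [torch.randn(1, 3, 16, 16) for _ in range(4)]
recurrent = torch.zeros_like(frames[0])
results = []
for frame in frames:
    out = gen(torch.cat([frame, recurrent], dim=1))
    results.append(out)
    recurrent = out.detach()  # cut the graph between time steps

loss = torch.stack(results).mean()
loss.backward()  # gradients flow within each step's forward pass, never across steps
```

The file then closes by re-homing the `ImageFlowInjector` that was deleted from injectors.py above: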
```diff
+
+
+class ImageFlowInjector(injectors.Injector):
+    def __init__(self, opt, env):
+        # Requires building this custom cuda kernel. Only require it if explicitly needed.
+        from models.networks.layers.resample2d_package.resample2d import Resample2d
+        super(ImageFlowInjector, self).__init__(opt, env)
+        self.resample = Resample2d()
+
+    def forward(self, state):
+        return self.resample(state[self.opt['in']], state[self.opt['flow']])
```
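`Resample2d` is the custom CUDA kernel from the flownet2-style `resample2d_package`; it backward-warps an image by a per-pixel flow field. Where building that kernel is not an option, roughly similar behavior can be sketched with `torch.nn.functional.grid_sample` — the helper below is an assumption for illustration, not the commit's code or a drop-in replacement:

```python
import torch
import torch.nn.functional as F

def warp(img, flow):
    # img: (N, C, H, W); flow: (N, 2, H, W) in pixels, channel 0 = dx, channel 1 = dy.
    n, _, h, w = img.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    grid = torch.stack((xs, ys), dim=0).float().unsqueeze(0) + flow  # absolute sample positions
    # Normalize to [-1, 1] as grid_sample expects, then move coordinates to the last dim.
    grid[:, 0] = 2.0 * grid[:, 0] / max(w - 1, 1) - 1.0
    grid[:, 1] = 2.0 * grid[:, 1] / max(h - 1, 1) - 1.0
    return F.grid_sample(img, grid.permute(0, 2, 3, 1), align_corners=True)

img = torch.randn(1, 3, 8, 8)
print(torch.allclose(warp(img, torch.zeros(1, 2, 8, 8)), img))  # zero flow is (near-)identity
```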