DL-Art-School/codes/models/steps/injectors.py

59 lines
2.0 KiB
Python
Raw Normal View History

2020-08-22 14:24:34 +00:00
import torch.nn
from models.archs.SPSR_arch import ImageGradientNoPadding
# Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
def create_injector(opt_inject, env):
    """Build the injector selected by ``opt_inject['type']``.

    Args:
        opt_inject: Injector options dict. ``opt_inject['type']`` selects the
            implementation ('img_grad', 'add_noise' or 'greyscale'); the whole
            dict is forwarded to that injector's constructor.
        env: Passed through unchanged to the injector's constructor.

    Returns:
        The constructed injector instance.

    Raises:
        NotImplementedError: if ``opt_inject['type']`` is not a known type.
    """
    # Named ``inject_type`` rather than ``type`` to avoid shadowing the builtin.
    inject_type = opt_inject['type']
    if inject_type == 'img_grad':
        return ImageGradientInjector(opt_inject, env)
    elif inject_type == 'add_noise':
        return AddNoiseInjector(opt_inject, env)
    elif inject_type == 'greyscale':
        return GreyInjector(opt_inject, env)
    else:
        raise NotImplementedError(f"Unknown injector type: {inject_type!r}")
class Injector(torch.nn.Module):
    """Base class for injectors, which synthesize new state entries in a step.

    ``opt`` must contain ``'in'`` and ``'out'`` keys; they are cached on the
    instance as ``self.input`` and ``self.output``.
    """

    def __init__(self, opt, env):
        super().__init__()
        self.opt = opt
        self.env = env
        # Convenience copies of the state keys this injector reads / writes.
        self.input, self.output = opt['in'], opt['out']

    def forward(self, state):
        """Return a dict of new state variables. Subclasses must override."""
        raise NotImplementedError
class ImageGradientInjector(Injector):
    """Compute the image gradient of ``state[opt['in']]`` and emit it under ``opt['out']``.

    The gradient is produced by a ``ImageGradientNoPadding`` module created at
    construction time.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.img_grad_fn = ImageGradientNoPadding()

    def forward(self, state):
        out_key = self.opt['out']
        source = state[self.opt['in']]
        return {out_key: self.img_grad_fn(source)}
class AddNoiseInjector(Injector):
    """Add Gaussian noise to ``state[opt['in']]`` and emit it under ``opt['out']``.

    The noise is standard-normal noise with the same shape as the input,
    multiplied by ``opt['scale']``; it is not clamped to any range.
    """

    def forward(self, state):
        clean = state[self.opt['in']]
        noise = torch.randn_like(clean) * self.opt['scale']
        return {self.opt['out']: clean + noise}
class GreyInjector(Injector):
    """Greyscale ``state[opt['in']]`` and emit it under ``opt['out']``.

    The input is averaged over its channel dimension (dim 1), and that mean is
    repeated back across dim 1. The output therefore has the same shape as the
    input, and every channel holds the per-pixel channel average.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)

    def forward(self, state):
        img = state[self.opt['in']]
        mean = torch.mean(img, dim=1, keepdim=True)
        # Tensor.repeat takes per-dimension repeat counts (there is no
        # torch.repeat function, and -1 is not a valid count). Tile only along
        # dim 1, back to the input's channel count, so the shape is preserved.
        reps = [1] * img.dim()
        reps[1] = img.shape[1]
        return {self.opt['out']: mean.repeat(*reps)}