2020-10-31 17:08:55 +00:00
|
|
|
import random
|
|
|
|
|
2020-08-22 14:24:34 +00:00
|
|
|
import torch.nn
|
2020-10-22 20:39:19 +00:00
|
|
|
from torch.cuda.amp import autocast
|
|
|
|
|
2020-10-14 02:56:39 +00:00
|
|
|
from utils.weight_scheduler import get_scheduler_for_opt
|
2020-12-18 16:18:34 +00:00
|
|
|
from trainer.losses import extract_params_from_state
|
2020-08-22 14:24:34 +00:00
|
|
|
|
|
|
|
# Injectors are a way to synthesize data within a step that can then be used (and reused) by loss functions.
|
|
|
|
def create_injector(opt_inject, env):
    """Build the injector described by an options dict.

    Args:
        opt_inject: Injector options. ``opt_inject['type']`` selects the
            implementation; the whole dict is forwarded to the selected
            constructor or factory function.
        env: Environment object, forwarded unchanged to the selected
            constructor or factory function.

    Returns:
        The injector produced by the selected constructor or factory.

    Raises:
        NotImplementedError: If ``opt_inject['type']`` matches no known injector.
    """
    # Deliberately not called `type`, so the builtin is not shadowed.
    inj_type = opt_inject['type']

    # Substring-matched families are tested first, in this order, and delegate to
    # factories that are imported only when their branch is taken.
    if 'teco_' in inj_type:
        from trainer.custom_training_components import create_teco_injector
        return create_teco_injector(opt_inject, env)
    elif 'progressive_' in inj_type:
        from trainer.custom_training_components import create_progressive_zoom_injector
        return create_progressive_zoom_injector(opt_inject, env)
    elif 'stereoscopic_' in inj_type:
        from trainer.custom_training_components import create_stereoscopic_injector
        return create_stereoscopic_injector(opt_inject, env)
    elif 'igpt' in inj_type:
        from models.transformers.igpt import gpt2
        return gpt2.create_injector(opt_inject, env)
    # Exact-name matches for the injectors defined in this module.
    elif inj_type == 'generator':
        return ImageGeneratorInjector(opt_inject, env)
    elif inj_type == 'discriminator':
        return DiscriminatorInjector(opt_inject, env)
    elif inj_type == 'scheduled_scalar':
        return ScheduledScalarInjector(opt_inject, env)
    elif inj_type == 'add_noise':
        return AddNoiseInjector(opt_inject, env)
    elif inj_type == 'greyscale':
        return GreyInjector(opt_inject, env)
    elif inj_type == 'interpolate':
        return InterpolateInjector(opt_inject, env)
    elif inj_type == 'image_patch':
        return ImagePatchInjector(opt_inject, env)
    elif inj_type == 'concatenate':
        return ConcatenateInjector(opt_inject, env)
    elif inj_type == 'margin_removal':
        return MarginRemoval(opt_inject, env)
    elif inj_type == 'foreach':
        return ForEachInjector(opt_inject, env)
    elif inj_type == 'constant':
        return ConstantInjector(opt_inject, env)
    elif inj_type == 'extract_indices':
        return IndicesExtractor(opt_inject, env)
    elif inj_type == 'random_shift':
        return RandomShiftInjector(opt_inject, env)
    elif inj_type == 'batch_rotate':
        return BatchRotateInjector(opt_inject, env)
    elif inj_type == 'sr_diffs':
        return SrDiffsInjector(opt_inject, env)
    elif inj_type == 'multiframe_combiner':
        return MultiFrameCombiner(opt_inject, env)
    else:
        # Name the offending type so a misconfigured option file is easy to diagnose.
        raise NotImplementedError('Unrecognized injector type: %r' % (inj_type,))
|
|
|
|
|
|
|
|
|
|
|
|
class Injector(torch.nn.Module):
    """Base class for injectors.

    Stores the options dict and environment, and records the state keys the
    injector reads from (``opt['in']``, optional) and writes to (``opt['out']``).
    """

    def __init__(self, opt, env):
        super().__init__()
        self.opt, self.env = opt, env
        # 'in' is optional; when it is absent no `input` attribute is created.
        if 'in' in opt:
            self.input = opt['in']
        self.output = opt['out']

    def forward(self, state):
        """Subclasses override this; it should return a dict of new state variables."""
        raise NotImplementedError
|
|
|
|
|
2020-08-23 23:22:34 +00:00
|
|
|
# Uses a generator to synthesize an image from [in] and injects the results into [out]
|
|
|
|
# Note that results are *not* detached.
|
|
|
|
class ImageGeneratorInjector(Injector):
    """Runs a generator from ``env['generators']`` on [in] and injects its results into [out].

    The generator is looked up by ``opt['generator']``. Results are not detached.
    When ``opt['grad']`` is present and falsy, the generator is called under
    ``torch.no_grad()``.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        # Defaults to True when the option is not given.
        self.grad = opt.get('grad', True)

    def forward(self, state):
        """Call the generator and return ``{output_key: result}`` mappings.

        If [out] is a list, the generator must return a list or tuple, whose
        entries are assigned to the output keys positionally; otherwise the
        whole result is stored under the single output key.
        """
        gen = self.env['generators'][self.opt['generator']]
        with autocast(enabled=self.env['opt']['fp16']):
            # A list-valued [in] is resolved by the helper; a single key becomes a one-element list.
            if isinstance(self.input, list):
                params = extract_params_from_state(self.input, state)
            else:
                params = [state[self.input]]

            if not self.grad:
                with torch.no_grad():
                    results = gen(*params)
            else:
                results = gen(*params)

        if not isinstance(self.output, list):
            return {self.output: results}

        # Only dereference tuples or lists, not tensors.
        assert isinstance(results, (list, tuple))
        return {key: results[idx] for idx, key in enumerate(self.output)}
|
|
|
|
|
|
|
|
|
2020-09-17 19:30:32 +00:00
|
|
|
# Injects a result from a discriminator network into the state.
|
|
|
|
class DiscriminatorInjector(Injector):
    """Injects the result of a discriminator from ``env['discriminators']`` into the state.

    The discriminator is looked up by ``opt['discriminator']`` and called on the
    state value(s) named by [in].
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)

    def forward(self, state):
        """Call the discriminator and return ``{output_key: result}`` mappings.

        A list-valued [in] is unpacked into positional arguments. If [out] is a
        list, the discriminator must return a list or tuple that is distributed
        over the output keys positionally.
        """
        with autocast(enabled=self.env['opt']['fp16']):
            disc = self.env['discriminators'][self.opt['discriminator']]
            if isinstance(self.input, list):
                results = disc(*[state[key] for key in self.input])
            else:
                results = disc(state[self.input])

        if not isinstance(self.output, list):
            return {self.output: results}

        # Only dereference tuples or lists, not tensors.
        assert isinstance(results, (list, tuple))
        return {key: results[idx] for idx, key in enumerate(self.output)}
|
|
|
|
|
|
|
|
|
2020-08-23 23:22:34 +00:00
|
|
|
# Injects a scalar that is modulated with a specified schedule. Useful for increasing or decreasing the influence
|
|
|
|
# of something over time.
|
|
|
|
class ScheduledScalarInjector(Injector):
    """Injects a scalar obtained from a scheduler for the current step.

    The scheduler is built from ``opt['scheduler']`` and queried with
    ``env['step']`` on every forward pass.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.scheduler = get_scheduler_for_opt(opt['scheduler'])

    def forward(self, state):
        """Return ``{opt['out']: weight}`` for the step currently stored in the environment."""
        out_key = self.opt['out']
        current_step = self.env['step']
        return {out_key: self.scheduler.get_weight_for_step(current_step)}
|
|
|
|
|
|
|
|
|
2020-08-22 19:08:33 +00:00
|
|
|
# Adds noise to [in] and injects the sum into [out]. In 'normal' mode the noise is gaussian multiplied by [scale]; in 'uniform' mode it is drawn uniformly from [0, [scale]).
|
|
|
|
class AddNoiseInjector(Injector):
    """Adds random noise to [in] and injects the sum into [out].

    Options read:
        mode: 'normal' (default) multiplies ``torch.randn_like`` noise by the
            scale; 'uniform' draws noise uniformly from ``[0, scale)``.
        scale: Either a number, or a string naming a state key that holds the
            scale (e.g. one produced by ScheduledScalarInjector). A value of
            None is treated as 1.
    """

    def __init__(self, opt, env):
        super(AddNoiseInjector, self).__init__(opt, env)
        self.mode = opt['mode'] if 'mode' in opt.keys() else 'normal'

    def forward(self, state):
        """Return ``{opt['out']: state[opt['in']] + noise}``.

        Raises:
            NotImplementedError: If the configured mode is neither 'normal'
                nor 'uniform'.
        """
        # Scale can be a fixed float, or a state key (e.g. from ScheduledScalarInjector).
        scale_opt = self.opt['scale']
        scale = state[scale_opt] if isinstance(scale_opt, str) else scale_opt
        if scale is None:
            scale = 1

        ref = state[self.opt['in']]
        if self.mode == 'normal':
            noise = torch.randn_like(ref) * scale
        elif self.mode == 'uniform':
            # Kept as a float32 tensor created on CPU and then moved, to preserve the
            # existing dtype and random-generator behaviour.
            noise = torch.FloatTensor(ref.shape).uniform_(0.0, scale).to(ref.device)
        else:
            # Previously an unknown mode fell through and failed with an
            # UnboundLocalError on `noise`; fail with an explicit message instead.
            raise NotImplementedError('Unknown noise mode: %r' % (self.mode,))
        return {self.opt['out']: ref + noise}
|
|
|
|
|
|
|
|
|
|
|
|
# Averages the channel dimension (1) of [in] and saves to [out]. Dimensions are
|
|
|
|
# kept the same for 3-channel inputs: the average is simply repeated 3 times along the channel dimension.
|
|
|
|
class GreyInjector(Injector):
    """Averages [in] over dimension 1 and injects that mean, repeated 3 times along dimension 1, into [out]."""

    def __init__(self, opt, env):
        super().__init__(opt, env)

    def forward(self, state):
        """Return ``{opt['out']: grey}`` where grey is the dim-1 mean tiled to 3 channels."""
        source = state[self.opt['in']]
        # keepdim=True retains dimension 1 (size 1) so it can be tiled by repeat().
        grey = torch.mean(source, dim=1, keepdim=True).repeat(1, 3, 1, 1)
        return {self.opt['out']: grey}
|
2020-09-03 17:32:47 +00:00
|
|
|
|
2020-09-27 03:25:32 +00:00
|
|
|
|
2020-09-03 17:32:47 +00:00
|
|
|
class InterpolateInjector(Injector):
    """Resizes [in] with ``torch.nn.functional.interpolate`` and injects the result into [out].

    Options read:
        scale_factor: Passed as the interpolation scale factor when given.
        size: Used only when no scale factor is given; the single value is
            used for both entries of a 2-tuple target size.
        mode: Passed as the interpolation mode.
    """

    def __init__(self, opt, env):
        super(InterpolateInjector, self).__init__(opt, env)
        scale_factor = opt['scale_factor'] if 'scale_factor' in opt.keys() else None
        if scale_factor is not None:
            self.scale_factor = scale_factor
            self.size = None
        else:
            self.scale_factor = None
            self.size = (opt['size'], opt['size'])

    def forward(self, state):
        """Return ``{opt['out']: resized}``."""
        # Use the values resolved in __init__ rather than indexing the options again:
        # a plain dict that specifies only one of 'scale_factor'/'size' would raise
        # KeyError on the other key.
        scaled = torch.nn.functional.interpolate(state[self.opt['in']],
                                                 scale_factor=self.scale_factor,
                                                 size=self.size,
                                                 mode=self.opt['mode'])
        return {self.opt['out']: scaled}
|
2020-09-27 03:25:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Extracts four patches from the input image, each a square of 'patch_size'. The input images are taken from each
|
|
|
|
# of the four corners of the image. The intent of this loss is that each patch shares some part of the input, which
|
|
|
|
# can then be used in the translation invariance loss.
|
|
|
|
#
|
|
|
|
# This injector is unique in that it does not only produce the specified output label into state. Instead it produces five
|
|
|
|
# outputs for the specified label, one for each corner of the input as well as the specified output, which is the top left
|
|
|
|
# corner. See the code below to find out how this works.
|
|
|
|
#
|
|
|
|
# Another note: this injector operates differently in eval mode (e.g. when env['training']=False) - in this case, it
|
|
|
|
# simply sets all the output state variables to the input. This is so that you can feed the output of this injector
|
|
|
|
# directly into your generator in training without affecting test performance.
|
|
|
|
class ImagePatchInjector(Injector):
    """Extracts square corner patches of side ``opt['patch_size']`` from [in].

    Five state entries are produced: [out] itself plus ``[out]_top_left``,
    ``[out]_top_right``, ``[out]_bottom_left`` and ``[out]_bottom_right``.
    When ``env['training']`` is falsy, all five are set to the unmodified input.
    If ``opt['resize']`` is given, every entry is resized to a square of that
    size with nearest-neighbour interpolation.
    """

    _CORNERS = ('top_left', 'top_right', 'bottom_left', 'bottom_right')

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.patch_size = opt['patch_size']
        # If specified, the output is resized to a square with this size after patch extraction.
        self.resize = opt.get('resize', None)

    def forward(self, state):
        """Return the dict of patch entries described on the class."""
        im = state[self.opt['in']]
        out = self.opt['out']
        if self.env['training']:
            ps = self.patch_size
            leading, trailing = slice(None, ps), slice(-ps, None)
            # The primary output is the top-left patch restricted to the first 3 channels;
            # the named corner entries keep every channel.
            res = {out: im[:, :3, :ps, :ps]}
            windows = ((leading, leading), (leading, trailing),
                       (trailing, leading), (trailing, trailing))
            for corner, (rows, cols) in zip(self._CORNERS, windows):
                res['%s_%s' % (out, corner)] = im[:, :, rows, cols]
        else:
            res = {out: im}
            for corner in self._CORNERS:
                res['%s_%s' % (out, corner)] = im

        if self.resize is not None:
            target = (self.resize, self.resize)
            res = {key: torch.nn.functional.interpolate(val, size=target, mode="nearest")
                   for key, val in res.items()}
        return res
|
2020-10-07 15:02:42 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Concatenates a list of tensors on the specified dimension.
|
|
|
|
class ConcatenateInjector(Injector):
    """Concatenates the state tensors named by the list [in] along ``opt['dim']`` into [out]."""

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.dim = opt['dim']

    def forward(self, state):
        """Return ``{opt['out']: torch.cat(tensors, dim)}`` with tensors taken in [in] order."""
        tensors = []
        for key in self.input:
            tensors.append(state[key])
        joined = torch.cat(tensors, dim=self.dim)
        return {self.opt['out']: joined}
|
|
|
|
|
|
|
|
|
|
|
|
# Removes margins from an image.
|
|
|
|
class MarginRemoval(Injector):
    """Crops ``opt['margin']`` elements from both ends of dimensions 2 and 3 of [in].

    If ``opt['random_shift_max']`` is given and positive, the crop window of
    every batch element is additionally offset by independent random integer
    amounts along dimensions 2 and 3.
    """

    def __init__(self, opt, env):
        super(MarginRemoval, self).__init__(opt, env)
        self.margin = opt['margin']
        self.random_shift_max = opt['random_shift_max'] if 'random_shift_max' in opt.keys() else 0

    def forward(self, state):
        """Return ``{opt['out']: cropped}``."""
        images = state[self.input]
        margin = self.margin
        size2, size3 = images.shape[2], images.shape[3]
        # Window ends are computed as explicit positive indices. A negative end index of
        # the form `-(margin - shift)` becomes `-0` (an empty slice) when margin == 0 or
        # shift == margin.
        if self.random_shift_max > 0:
            # A shift larger than the margin would move the window outside the input,
            # so the shift range is limited to the margin.
            max_shift = min(self.random_shift_max, margin)
            crops = []
            for b in range(images.shape[0]):
                shift2 = random.randint(-max_shift, max_shift)
                shift3 = random.randint(-max_shift, max_shift)
                crops.append(images[b, :,
                                    margin + shift2:size2 - margin + shift2,
                                    margin + shift3:size3 - margin + shift3])
            output = torch.stack(crops, dim=0)
        else:
            output = images[:, :, margin:size2 - margin, margin:size3 - margin]

        return {self.opt['out']: output}
|
|
|
|
|
2020-10-11 03:50:23 +00:00
|
|
|
|
2020-10-11 04:39:55 +00:00
|
|
|
# Produces an injection which is composed of applying a single injector multiple times across a single dimension.
|
|
|
|
class ForEachInjector(Injector):
    """Applies one wrapped injector to every slice of [in] along dimension 1.

    The wrapped injector is created from a copy of the options with 'type' set
    to ``opt['subtype']`` and its input/output keys set to '_in'/'_out'. The
    per-slice results are returned as a list when ``opt['aslist']`` is truthy,
    otherwise stacked along dimension 1.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        sub_opt = opt.copy()
        sub_opt.update({'type': opt['subtype'], 'in': '_in', 'out': '_out'})
        self.injector = create_injector(sub_opt, self.env)
        self.aslist = opt.get('aslist', False)

    def forward(self, state):
        """Return ``{output_key: results}`` collected over dimension 1 of the input."""
        collected = []
        # Work on a shallow copy so the temporary '_in' key never leaks into the caller's state.
        scratch = state.copy()
        inputs = state[self.opt['in']]
        for idx in range(inputs.shape[1]):
            scratch['_in'] = inputs[:, idx]
            collected.append(self.injector(scratch)['_out'])
        return {self.output: collected if self.aslist else torch.stack(collected, dim=1)}
|
2020-10-11 14:20:07 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ConstantInjector(Injector):
    """Injects a constant tensor shaped like another state tensor.

    ``opt['like']`` names the state tensor passed to ``torch.zeros_like``;
    ``opt['constant_type']`` must be 'zeroes'.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.constant_type = opt['constant_type']
        # This injector uses this tensor to determine what batch size and device to use.
        self.like = opt['like']

    def forward(self, state):
        """Return ``{opt['out']: constant}``; raises NotImplementedError for unknown constant types."""
        template = state[self.like]
        if self.constant_type != 'zeroes':
            raise NotImplementedError
        return {self.opt['out']: torch.zeros_like(template)}
|
2020-10-22 04:22:00 +00:00
|
|
|
|
|
|
|
|
2020-10-24 17:56:39 +00:00
|
|
|
class IndicesExtractor(Injector):
    """Splits [in] along dimension 1, storing index ``i`` under the ``i``-th key of the list [out].

    Only ``opt['dim'] == 1`` is supported; this is asserted at construction.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.dim = opt['dim']
        # Honestly not sure how to support an abstract dim here, so just add yours when needed.
        assert self.dim == 1

    def forward(self, state):
        """Return ``{out_key_i: state[in][:, i]}`` for every output key."""
        return {key: state[self.input][:, idx]
                for idx, key in enumerate(self.output)
                if self.dim == 1}
|
|
|
|
|
2020-10-31 17:08:55 +00:00
|
|
|
|
|
|
|
class RandomShiftInjector(Injector):
    """Copies [in] to [out] unchanged; no shift is applied by this implementation."""

    def __init__(self, opt, env):
        super().__init__(opt, env)

    def forward(self, state):
        """Return ``{output_key: state[input_key]}``."""
        return {self.output: state[self.input]}
|
|
|
|
|
|
|
|
|
|
|
|
class BatchRotateInjector(Injector):
    """Rolls [in] by one position along dimension 0 and injects the result into [out]."""

    def __init__(self, opt, env):
        super().__init__(opt, env)

    def forward(self, state):
        """Return ``{output_key: torch.roll(state[input_key], 1, 0)}``."""
        rolled = torch.roll(state[self.input], 1, 0)
        return {self.output: rolled}
|
|
|
|
|
2020-11-14 03:11:50 +00:00
|
|
|
|
|
|
|
# Injector used to work with image deltas used in diff-SR
|
|
|
|
class SrDiffsInjector(Injector):
    """Injector used to work with image deltas used in diff-SR.

    Modes (``opt['mode']``):
        'produce_diff': computes ``state[hq] - state[lq]``, stores it under
            ``opt['diff']``, and stores under [out] the dim-1 concatenation of
            lq and the diff (plus hq when ``opt['include_combined']`` is truthy).
        'recombine': stores ``state[lq] + state[hq]`` under [out].
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.mode = opt['mode']
        assert self.mode in ['recombine', 'produce_diff']
        self.lq, self.hq = opt['lq'], opt['hq']
        # These options are only consulted by forward() in 'produce_diff' mode.
        if self.mode == 'produce_diff':
            self.diff_key = opt['diff']
            self.include_combined = opt['include_combined']

    def forward(self, state):
        """Return the new state entries for the configured mode."""
        resampled_lq = state[self.lq]
        hq = state[self.hq]
        if self.mode == 'produce_diff':
            diff = hq - resampled_lq
            parts = [resampled_lq, diff]
            if self.include_combined:
                parts.append(hq)
            return {self.output: torch.cat(parts, dim=1),
                    self.diff_key: diff}
        if self.mode == 'recombine':
            return {self.output: resampled_lq + hq}
|
2020-11-29 22:39:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MultiFrameCombiner(Injector):
    """Combines or synthesizes multi-frame inputs, depending on ``opt['mode']``.

    'combine': unpacks ``state[opt['in']]`` as a 5-D tensor whose dimension 1
        indexes frames. Every non-center frame is passed through the flow
        network ``env['generators'][opt['flow']]`` together with the center
        frame and resampled with the resulting flow field. The frames are then
        concatenated along dimension 1.
    'synthesize': repeats ``state[opt['in']]`` ``opt['dim']`` times along dimension 1.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.mode = opt['mode']
        self.dim = opt.get('dim', None)
        self.flow = opt['flow']
        self.in_lq_key, self.in_hq_key = opt['in'], opt['in_hq']
        self.out_lq_key, self.out_hq_key = opt['out'], opt['out_hq']
        from models.flownet2.networks import Resample2d
        self.resampler = Resample2d()

    def combine(self, state):
        """Resample all frames using flow against the center frame and concatenate them.

        Returns a dict with the dim-1 concatenation under the LQ output key, the
        center HQ frame under the HQ output key, and the dim-0 concatenation of
        the same frames under ``<out>_flow_sample``.
        """
        flow_net = self.env['generators'][self.flow]
        lq = state[self.in_lq_key]
        hq = state[self.in_hq_key]
        # Unpacking five values requires the LQ input to be 5-D; only the frame count is used.
        _, frame_count, _, _, _ = lq.shape
        center = frame_count // 2
        center_img = lq[:, center, :, :, :]
        # The center frame always comes first, followed by the other frames in index order.
        frames = [center_img]
        with torch.no_grad():
            for idx in range(frame_count):
                if idx != center:
                    neighbor = lq[:, idx, :, :, :]
                    flowfield = flow_net(torch.stack([center_img, neighbor], dim=2).float())
                    frames.append(self.resampler(neighbor, flowfield))
        hq_out = hq[:, center, :, :, :]
        return {self.out_lq_key: torch.cat(frames, dim=1),
                self.out_hq_key: hq_out,
                self.out_lq_key + "_flow_sample": torch.cat(frames, dim=0)}

    def synthesize(self, state):
        """Return the LQ input repeated ``self.dim`` times along dimension 1."""
        lq = state[self.in_lq_key]
        repeated = lq.repeat(1, self.dim, 1, 1)
        return {self.out_lq_key: repeated}

    def forward(self, state):
        """Dispatch on the configured mode; raises NotImplementedError for unknown modes."""
        if self.mode == "synthesize":
            return self.synthesize(state)
        if self.mode == "combine":
            return self.combine(state)
        raise NotImplementedError
|