Add constant injector

This commit is contained in:
James Betker 2020-10-10 21:50:23 -06:00
parent f99812e14d
commit 120072d464
2 changed files with 23 additions and 4 deletions

View File

@ -34,6 +34,8 @@ def create_injector(opt_inject, env):
return ConcatenateInjector(opt_inject, env)
elif type == 'margin_removal':
return MarginRemoval(opt_inject, env)
elif type == 'constant':
return ConstantInjector(opt_inject, env)
else:
raise NotImplementedError
@ -216,4 +218,21 @@ class MarginRemoval(Injector):
def forward(self, state):
input = state[self.input]
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
class ConstantInjector(Injector):
def __init__(self, opt, env):
super(ConstantInjector, self).__init__(opt, env)
self.constant_type = opt['constant_type']
self.dim = opt['dim']
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
def forward(self, state):
bs = state[self.like].shape[0]
dev = state[self.like].device
if self.constant_type == 'zeros':
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
else:
raise NotImplementedError
return { self.opt['out']: out }

View File

@ -93,7 +93,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
# Resample does not work in FP16.
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
@ -113,11 +113,11 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
input[self.recurrent_index
] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.self.recurrent_index], debug_index)
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor):