Add constant injector
This commit is contained in:
parent
f99812e14d
commit
120072d464
|
@ -34,6 +34,8 @@ def create_injector(opt_inject, env):
|
|||
return ConcatenateInjector(opt_inject, env)
|
||||
elif type == 'margin_removal':
|
||||
return MarginRemoval(opt_inject, env)
|
||||
elif type == 'constant':
|
||||
return ConstantInjector(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -216,4 +218,21 @@ class MarginRemoval(Injector):
|
|||
|
||||
def forward(self, state):
|
||||
input = state[self.input]
|
||||
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
||||
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
||||
|
||||
|
||||
class ConstantInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ConstantInjector, self).__init__(opt, env)
|
||||
self.constant_type = opt['constant_type']
|
||||
self.dim = opt['dim']
|
||||
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
|
||||
|
||||
def forward(self, state):
|
||||
bs = state[self.like].shape[0]
|
||||
dev = state[self.like].device
|
||||
if self.constant_type == 'zeros':
|
||||
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return { self.opt['out']: out }
|
||||
|
|
|
@ -93,7 +93,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||
# Resample does not work in FP16.
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
|
||||
input[self.recurrent_index] = recurrent_input
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||
|
@ -113,11 +113,11 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
|
||||
input[self.recurrent_index
|
||||
] = recurrent_input
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.self.recurrent_index], debug_index)
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||
debug_index += 1
|
||||
gen_out = gen(*input)
|
||||
if isinstance(gen_out, torch.Tensor):
|
||||
|
|
Loading…
Reference in New Issue
Block a user