Add constant injector
This commit is contained in:
parent
f99812e14d
commit
120072d464
|
@ -34,6 +34,8 @@ def create_injector(opt_inject, env):
|
||||||
return ConcatenateInjector(opt_inject, env)
|
return ConcatenateInjector(opt_inject, env)
|
||||||
elif type == 'margin_removal':
|
elif type == 'margin_removal':
|
||||||
return MarginRemoval(opt_inject, env)
|
return MarginRemoval(opt_inject, env)
|
||||||
|
elif type == 'constant':
|
||||||
|
return ConstantInjector(opt_inject, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -216,4 +218,21 @@ class MarginRemoval(Injector):
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
input = state[self.input]
|
input = state[self.input]
|
||||||
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
||||||
|
|
||||||
|
|
||||||
|
class ConstantInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(ConstantInjector, self).__init__(opt, env)
|
||||||
|
self.constant_type = opt['constant_type']
|
||||||
|
self.dim = opt['dim']
|
||||||
|
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
bs = state[self.like].shape[0]
|
||||||
|
dev = state[self.like].device
|
||||||
|
if self.constant_type == 'zeros':
|
||||||
|
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return { self.opt['out']: out }
|
||||||
|
|
|
@ -93,7 +93,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||||
# Resample does not work in FP16.
|
# Resample does not work in FP16.
|
||||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
|
||||||
input[self.recurrent_index] = recurrent_input
|
input[self.recurrent_index] = recurrent_input
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||||
|
@ -113,11 +113,11 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
|
||||||
input[self.recurrent_index
|
input[self.recurrent_index
|
||||||
] = recurrent_input
|
] = recurrent_input
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.self.recurrent_index], debug_index)
|
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||||
debug_index += 1
|
debug_index += 1
|
||||||
gen_out = gen(*input)
|
gen_out = gen(*input)
|
||||||
if isinstance(gen_out, torch.Tensor):
|
if isinstance(gen_out, torch.Tensor):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user