diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 7d0584cc..f0d26185 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -34,6 +34,8 @@ def create_injector(opt_inject, env): return ConcatenateInjector(opt_inject, env) elif type == 'margin_removal': return MarginRemoval(opt_inject, env) + elif type == 'constant': + return ConstantInjector(opt_inject, env) else: raise NotImplementedError @@ -216,4 +218,21 @@ class MarginRemoval(Injector): def forward(self, state): input = state[self.input] - return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} \ No newline at end of file + return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} + + +class ConstantInjector(Injector): + def __init__(self, opt, env): + super(ConstantInjector, self).__init__(opt, env) + self.constant_type = opt['constant_type'] + self.dim = opt['dim'] + self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use. + + def forward(self, state): + bs = state[self.like].shape[0] + dev = state[self.like].device + if self.constant_type == 'zeros': + out = torch.zeros((bs,) + tuple(self.dim), device=dev) + else: + raise NotImplementedError + return { self.opt['out']: out } diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 5fe1a8ec..7f50f3e9 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -93,7 +93,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') # Resample does not work in FP16. - recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) + recurrent_input = self.resample(recurrent_input.float(), flowfield.float()) input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) @@ -113,11 +113,11 @@ class RecurrentImageGeneratorSequenceInjector(Injector): reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') - recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) + recurrent_input = self.resample(recurrent_input.float(), flowfield.float()) input[self.recurrent_index ] = recurrent_input if self.env['step'] % 50 == 0: - self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.self.recurrent_index], debug_index) + self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) debug_index += 1 gen_out = gen(*input) if isinstance(gen_out, torch.Tensor):