diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 304e58e8..9f320506 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -67,6 +67,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): self.resample = Resample2d() self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'. self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True + self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index. def forward(self, state): gen = self.env['generators'][self.opt['generator']] @@ -89,7 +90,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector): for i in range(f): if first_step: input = extract_inputs_index(first_inputs, i) - recurrent_input = torch.zeros_like(input[self.recurrent_index]) + if self.hq_recurrent: + recurrent_input = input[self.recurrent_index] + else: + recurrent_input = torch.zeros_like(input[self.recurrent_index]) first_step = False else: input = extract_inputs_index(inputs, i)