diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 5857a5f5..ea8b1017 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -95,6 +95,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector): self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) debug_index += 1 gen_out = gen(*input) + if isinstance(gen_out, torch.Tensor): + gen_out = [gen_out] recurrent_input = gen_out[self.output_hq_index] results.append(recurrent_input) @@ -113,6 +115,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector): self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) debug_index += 1 gen_out = gen(*input) + if isinstance(gen_out, torch.Tensor): + gen_out = [gen_out] recurrent_input = gen_out[self.output_hq_index] results.append(recurrent_input)