diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index b08e253d..91fd08bd 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -50,6 +50,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector): self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'. self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index. + self.hq_batched_output_key = opt['hq_batched_key'] if 'hq_batched_key' in opt.keys() else None def forward(self, state): gen = self.env['generators'][self.opt['generator']] @@ -138,7 +139,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector): final_results = {} # Include 'hq_batched' here - because why not... Don't really need a separate injector for this. b, s, c, h, w = state['hq'].shape - final_results['hq_batched'] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w) + if self.hq_batched_output_key is not None: + final_results[self.hq_batched_output_key] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w) for k, v in results.items(): final_results[k] = torch.stack(v, dim=1) final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.