diff --git a/codes/models/steps/progressive_zoom.py b/codes/models/steps/progressive_zoom.py index c516e0bb..3da24320 100644 --- a/codes/models/steps/progressive_zoom.py +++ b/codes/models/steps/progressive_zoom.py @@ -103,8 +103,10 @@ class ProgressiveGeneratorInjector(Injector): self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index) debug_index += 1 results[self.hq_output_key] = results_hq + + # Results are concatenated into the batch dimension, to allow normal losses to be used against the output. for k, v in results.items(): - results[k] = torch.stack(v, dim=1) + results[k] = torch.cat(v, dim=0) return results