Collapse progressive zoom candidates into the batch dimension

This contributes a significant speedup to training this type of network
since losses can operate on the entire prediction spectrum at once.
This commit is contained in:
James Betker 2020-10-21 22:37:23 -06:00
parent 680d635420
commit 43c4f92123

View File

@ -103,8 +103,10 @@ class ProgressiveGeneratorInjector(Injector):
self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index) self.produce_progressive_visual_debugs(chain_input, chain_output, debug_index)
debug_index += 1 debug_index += 1
results[self.hq_output_key] = results_hq results[self.hq_output_key] = results_hq
# Results are concatenated into the batch dimension, to allow normal losses to be used against the output.
for k, v in results.items(): for k, v in results.items():
results[k] = torch.stack(v, dim=1) results[k] = torch.cat(v, dim=0)
return results return results