Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-11-13 20:10:47 -07:00
commit 9c3d0b7560

View File

@ -50,6 +50,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index.
self.hq_batched_output_key = opt['hq_batched_key'] if 'hq_batched_key' in opt.keys() else None
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
@ -138,7 +139,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
final_results = {}
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
b, s, c, h, w = state['hq'].shape
final_results['hq_batched'] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w)
if self.hq_batched_output_key is not None:
final_results[self.hq_batched_output_key] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w)
for k, v in results.items():
final_results[k] = torch.stack(v, dim=1)
final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.