forked from mrq/DL-Art-School
Allow hq_batched_key to be specified
This commit is contained in:
parent
c47925ae34
commit
67bf55495b
|
@ -50,6 +50,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
|
||||
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
|
||||
self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index.
|
||||
self.hq_batched_output_key = opt['hq_batched_key'] if 'hq_batched_key' in opt.keys() else None
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
|
@ -138,7 +139,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
final_results = {}
|
||||
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
|
||||
b, s, c, h, w = state['hq'].shape
|
||||
final_results['hq_batched'] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w)
|
||||
if self.hq_batched_output_key is not None:
|
||||
final_results[self.hq_batched_output_key] = state['hq'].clone().permute(1,0,2,3,4).reshape(b*s, c, h, w)
|
||||
for k, v in results.items():
|
||||
final_results[k] = torch.stack(v, dim=1)
|
||||
final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.
|
||||
|
|
Loading…
Reference in New Issue
Block a user