Allow tecogan to support generators that only output a tensor (instead of a list)

This commit is contained in:
James Betker 2020-10-08 09:26:25 -06:00
parent 969bcd9021
commit c174ac0fd5

View File

@ -95,6 +95,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1 debug_index += 1
gen_out = gen(*input) gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out]
recurrent_input = gen_out[self.output_hq_index] recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input) results.append(recurrent_input)
@ -113,6 +115,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1 debug_index += 1
gen_out = gen(*input) gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out]
recurrent_input = gen_out[self.output_hq_index] recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input) results.append(recurrent_input)