Allow tecogan to support generators that only output a tensor (instead of a list)
This commit is contained in:
parent
969bcd9021
commit
c174ac0fd5
|
@ -95,6 +95,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||||
debug_index += 1
|
debug_index += 1
|
||||||
gen_out = gen(*input)
|
gen_out = gen(*input)
|
||||||
|
if isinstance(gen_out, torch.Tensor):
|
||||||
|
gen_out = [gen_out]
|
||||||
recurrent_input = gen_out[self.output_hq_index]
|
recurrent_input = gen_out[self.output_hq_index]
|
||||||
results.append(recurrent_input)
|
results.append(recurrent_input)
|
||||||
|
|
||||||
|
@ -113,6 +115,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||||
debug_index += 1
|
debug_index += 1
|
||||||
gen_out = gen(*input)
|
gen_out = gen(*input)
|
||||||
|
if isinstance(gen_out, torch.Tensor):
|
||||||
|
gen_out = [gen_out]
|
||||||
recurrent_input = gen_out[self.output_hq_index]
|
recurrent_input = gen_out[self.output_hq_index]
|
||||||
results.append(recurrent_input)
|
results.append(recurrent_input)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user