Fix tecogan_losses errors

This commit is contained in:
James Betker 2020-10-10 20:30:14 -06:00
parent 3a5b23b9f7
commit f99812e14d

View File

@ -96,7 +96,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.recurrent_index] = recurrent_input input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0: if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1 debug_index += 1
gen_out = gen(*input) gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor): if isinstance(gen_out, torch.Tensor):
@ -117,7 +117,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
input[self.recurrent_index input[self.recurrent_index
] = recurrent_input ] = recurrent_input
if self.env['step'] % 50 == 0: if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.self.recurrent_index], debug_index)
debug_index += 1 debug_index += 1
gen_out = gen(*input) gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor): if isinstance(gen_out, torch.Tensor):
@ -127,13 +127,13 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
return {self.output: results} return {self.output: results}
def produce_teco_visual_debugs(self, gen_input, it): def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
if self.env['rank'] > 0: if self.env['rank'] > 0:
return return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,))) torchvision.utils.save_image(gen_input, osp.join(base_path, "%s_img.png" % (it,)))
torchvision.utils.save_image(gen_input[:, 3:], osp.join(base_path, "%s_recurrent.png" % (it,))) torchvision.utils.save_image(gen_recurrent, osp.join(base_path, "%s_recurrent.png" % (it,)))
class FlowAdjustment(Injector): class FlowAdjustment(Injector):