Fix tecogan_losses errors
This commit is contained in:
parent
3a5b23b9f7
commit
f99812e14d
|
@ -96,7 +96,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||||
input[self.recurrent_index] = recurrent_input
|
input[self.recurrent_index] = recurrent_input
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
|
||||||
debug_index += 1
|
debug_index += 1
|
||||||
gen_out = gen(*input)
|
gen_out = gen(*input)
|
||||||
if isinstance(gen_out, torch.Tensor):
|
if isinstance(gen_out, torch.Tensor):
|
||||||
|
@ -117,7 +117,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
input[self.recurrent_index
|
input[self.recurrent_index
|
||||||
] = recurrent_input
|
] = recurrent_input
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.self.recurrent_index], debug_index)
|
||||||
debug_index += 1
|
debug_index += 1
|
||||||
gen_out = gen(*input)
|
gen_out = gen(*input)
|
||||||
if isinstance(gen_out, torch.Tensor):
|
if isinstance(gen_out, torch.Tensor):
|
||||||
|
@ -127,13 +127,13 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
|
|
||||||
return {self.output: results}
|
return {self.output: results}
|
||||||
|
|
||||||
def produce_teco_visual_debugs(self, gen_input, it):
|
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
|
||||||
if self.env['rank'] > 0:
|
if self.env['rank'] > 0:
|
||||||
return
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
|
torchvision.utils.save_image(gen_input, osp.join(base_path, "%s_img.png" % (it,)))
|
||||||
torchvision.utils.save_image(gen_input[:, 3:], osp.join(base_path, "%s_recurrent.png" % (it,)))
|
torchvision.utils.save_image(gen_recurrent, osp.join(base_path, "%s_recurrent.png" % (it,)))
|
||||||
|
|
||||||
|
|
||||||
class FlowAdjustment(Injector):
|
class FlowAdjustment(Injector):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user