Fix more distributed bugs

This commit is contained in:
James Betker 2020-10-08 14:32:45 -06:00
parent b36ba0460c
commit 1eb516d686
2 changed files with 10 additions and 1 deletions

View File

@ -298,6 +298,8 @@ class ExtensibleTrainer(BaseModel):
def load(self):
for netdict in [self.netsG, self.netsD]:
for name, net in netdict.items():
if not self.opt['networks'][name]['trainable']:
continue
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
if load_path is not None:
if self.rank <= 0:

View File

@ -7,6 +7,7 @@ import torch.nn.functional as F
import os
import os.path as osp
import torchvision
import torch.distributed as dist
def create_teco_loss(opt, env):
type = opt['type']
@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
return {self.output: results}
def produce_teco_visual_debugs(self, gen_input, it):
if dist.get_rank() > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss):
return l_total
def produce_teco_visual_debugs(self, sext, lbl, it):
if dist.get_rank() > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
os.makedirs(base_path, exist_ok=True)
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
@ -220,6 +225,8 @@ class PingPongLoss(ConfigurableLoss):
return l_total
def produce_teco_visual_debugs(self, imglist):
if dist.get_rank() > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
assert isinstance(imglist, list)