diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 079da6e1..88e34671 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -298,6 +298,8 @@ class ExtensibleTrainer(BaseModel): def load(self): for netdict in [self.netsG, self.netsD]: for name, net in netdict.items(): + if not self.opt['networks'][name]['trainable']: + continue load_path = self.opt['path']['pretrain_model_%s' % (name,)] if load_path is not None: if self.rank <= 0: diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index ea8b1017..1c033c60 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -7,6 +7,7 @@ import torch.nn.functional as F import os import os.path as osp import torchvision +import torch.distributed as dist def create_teco_loss(opt, env): type = opt['type'] @@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector): return {self.output: results} def produce_teco_visual_debugs(self, gen_input, it): + if dist.get_rank() > 0: + return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,))) @@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss): return l_total def produce_teco_visual_debugs(self, sext, lbl, it): + if dist.get_rank() > 0: + return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl) os.makedirs(base_path, exist_ok=True) lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c'] @@ -220,8 +225,10 @@ class PingPongLoss(ConfigurableLoss): return l_total def produce_teco_visual_debugs(self, imglist): + if dist.get_rank() > 0: + return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) assert isinstance(imglist, list) for i, img in enumerate(imglist): - torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) \ No newline at end of file + torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))