Fix more distributed bugs
This commit is contained in:
parent
b36ba0460c
commit
1eb516d686
|
@ -298,6 +298,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
def load(self):
|
||||
for netdict in [self.netsG, self.netsD]:
|
||||
for name, net in netdict.items():
|
||||
if not self.opt['networks'][name]['trainable']:
|
||||
continue
|
||||
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
|
||||
if load_path is not None:
|
||||
if self.rank <= 0:
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch.nn.functional as F
|
|||
import os
|
||||
import os.path as osp
|
||||
import torchvision
|
||||
import torch.distributed as dist
|
||||
|
||||
def create_teco_loss(opt, env):
|
||||
type = opt['type']
|
||||
|
@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
return {self.output: results}
|
||||
|
||||
def produce_teco_visual_debugs(self, gen_input, it):
|
||||
if dist.get_rank() > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
|
||||
|
@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
return l_total
|
||||
|
||||
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||
if dist.get_rank() > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
||||
|
@ -220,8 +225,10 @@ class PingPongLoss(ConfigurableLoss):
|
|||
return l_total
|
||||
|
||||
def produce_teco_visual_debugs(self, imglist):
    """Dump each image in ``imglist`` as a PNG for visual debugging.

    Files are written to
    ``<env['base_path']>/../visual_dbg/teco_pingpong/<env['step']>/<index>.png``,
    where ``<index>`` is the position of the image in ``imglist``.

    In a distributed run only rank 0 writes; every other rank returns
    immediately. In a non-distributed run the images are always written.

    Args:
        imglist: A list of images accepted by
            ``torchvision.utils.save_image``.
    """
    # dist.get_rank() raises when no default process group exists, so only
    # query it once distributed training is known to be initialised.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() > 0:
        return
    base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
    os.makedirs(base_path, exist_ok=True)
    assert isinstance(imglist, list)
    for i, img in enumerate(imglist):
        # One write per image; saving the same file twice is redundant.
        torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i,)))
|
||||
|
|
Loading…
Reference in New Issue
Block a user