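Fix more distributed bugs: skip non-trainable networks in ExtensibleTrainer.load() and return early from the teco visual-debug writers on ranks > 0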
This commit is contained in:
parent
b36ba0460c
commit
1eb516d686
|
@ -298,6 +298,8 @@ class ExtensibleTrainer(BaseModel):
|
||||||
def load(self):
|
def load(self):
|
||||||
for netdict in [self.netsG, self.netsD]:
|
for netdict in [self.netsG, self.netsD]:
|
||||||
for name, net in netdict.items():
|
for name, net in netdict.items():
|
||||||
|
if not self.opt['networks'][name]['trainable']:
|
||||||
|
continue
|
||||||
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
|
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
|
||||||
if load_path is not None:
|
if load_path is not None:
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
|
|
|
@ -7,6 +7,7 @@ import torch.nn.functional as F
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import torchvision
|
import torchvision
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
def create_teco_loss(opt, env):
|
def create_teco_loss(opt, env):
|
||||||
type = opt['type']
|
type = opt['type']
|
||||||
|
@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
return {self.output: results}
|
return {self.output: results}
|
||||||
|
|
||||||
def produce_teco_visual_debugs(self, gen_input, it):
|
def produce_teco_visual_debugs(self, gen_input, it):
|
||||||
|
if dist.get_rank() > 0:
|
||||||
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
|
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
|
||||||
|
@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss):
|
||||||
return l_total
|
return l_total
|
||||||
|
|
||||||
def produce_teco_visual_debugs(self, sext, lbl, it):
|
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||||
|
if dist.get_rank() > 0:
|
||||||
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
||||||
|
@ -220,6 +225,8 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
return l_total
|
return l_total
|
||||||
|
|
||||||
def produce_teco_visual_debugs(self, imglist):
|
def produce_teco_visual_debugs(self, imglist):
|
||||||
|
if dist.get_rank() > 0:
|
||||||
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
assert isinstance(imglist, list)
|
assert isinstance(imglist, list)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user