Fix ddp bug

This commit is contained in:
James Betker 2021-06-13 10:25:23 -06:00
parent 3e3ad7825f
commit 1cd75dfd33

View File

@ -98,7 +98,8 @@ class BaseModel():
return save_path return save_path
def load_network(self, load_path, network, strict=True, pretrain_base_path=None): def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): # Sometimes networks are passed in as DDP modules, we want the raw parameters.
if hasattr(network, 'module'):
network = network.module network = network.module
load_net = torch.load(load_path) load_net = torch.load(load_path)