Fix ddp bug

This commit is contained in:
James Betker 2021-06-13 10:25:23 -06:00
parent 3e3ad7825f
commit 1cd75dfd33

View File

@ -98,8 +98,9 @@ class BaseModel():
return save_path
def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
network = network.module
# Sometimes networks are passed in as DDP modules, we want the raw parameters.
if hasattr(network, 'module'):
network = network.module
load_net = torch.load(load_path)
# Support loading torch.save()s for whole models as well as just state_dicts.