From 1cd75dfd333eaf8a72eff6225f814e5c5e2c988b Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 13 Jun 2021 10:25:23 -0600 Subject: [PATCH] Fix ddp bug --- codes/trainer/base_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 9dacf96d..ab655b33 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -98,8 +98,9 @@ class BaseModel(): return save_path def load_network(self, load_path, network, strict=True, pretrain_base_path=None): - #if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): - network = network.module + # Sometimes networks are passed in as DDP modules, we want the raw parameters. + if hasattr(network, 'module'): + network = network.module load_net = torch.load(load_path) # Support loading torch.save()s for whole models as well as just state_dicts.