diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 9dacf96d..ab655b33 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -98,8 +98,9 @@ class BaseModel(): return save_path def load_network(self, load_path, network, strict=True, pretrain_base_path=None): - #if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): - network = network.module + # Sometimes networks are passed in as DDP modules, we want the raw parameters. + if hasattr(network, 'module'): + network = network.module load_net = torch.load(load_path) # Support loading torch.save()s for whole models as well as just state_dicts.