diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index b5563341..969f9ca5 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -106,7 +106,8 @@ class BaseModel(): # Sometimes networks are passed in as DDP modules, we want the raw parameters. if hasattr(network, 'module'): network = network.module - load_net = torch.load(load_path) + load_net = torch.load(load_path, + map_location=lambda storage, loc: storage.cuda(self.rank if self.rank != -1 else 0)) # Support loading torch.save()s for whole models as well as just state_dicts. if 'state_dict' in load_net: