diff --git a/codes/train.py b/codes/train.py
index 8d7551d9..2846fe09 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -17,7 +17,7 @@ from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
 from datetime import datetime
 
-from utils.util import opt_get
+from utils.util import opt_get, map_cuda_to_correct_device
 
 
 def init_dist(backend, **kwargs):
@@ -43,9 +43,7 @@ class Trainer:
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
-            device_id = torch.cuda.current_device()
-            resume_state = torch.load(opt['path']['resume_state'],
-                                      map_location=lambda storage, loc: storage.cuda(device_id))
+            resume_state = torch.load(opt['path']['resume_state'], map_location=map_cuda_to_correct_device)
         else:
             resume_state = None
 
diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py
index 969f9ca5..329140ad 100644
--- a/codes/trainer/base_model.py
+++ b/codes/trainer/base_model.py
@@ -106,8 +106,7 @@ class BaseModel():
         # Sometimes networks are passed in as DDP modules, we want the raw parameters.
         if hasattr(network, 'module'):
             network = network.module
-        load_net = torch.load(load_path,
-                              map_location=lambda storage, loc: storage.cuda(self.rank if self.rank != -1 else 0))
+        load_net = torch.load(load_path, map_location=utils.util.map_cuda_to_correct_device)
 
         # Support loading torch.save()s for whole models as well as just state_dicts.
         if 'state_dict' in load_net:
diff --git a/codes/utils/util.py b/codes/utils/util.py
index 75b6cf09..ff5267b7 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -488,3 +488,11 @@ def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load
     print(f"Loading from {load_path}")
     model.load_state_dict(torch.load(load_path), strict=strict_load)
     return model
+
+
+# Mapper for torch.load() that maps cuda devices to the correct CUDA device, but leaves CPU devices alone.
+def map_cuda_to_correct_device(storage, loc):
+    if str(loc).startswith('cuda'):
+        return storage.cuda(torch.cuda.current_device())
+    else:
+        return storage.cpu()
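
For reference, a minimal standalone sketch (not part of the patch) of how torch.load() exercises a callable map_location such as the map_cuda_to_correct_device helper added above; the checkpoint path shown is hypothetical and only illustrates the call pattern used at the two patched call sites.

import torch


def map_cuda_to_correct_device(storage, loc):
    # torch.load() invokes this once per storage; `loc` is the device tag the
    # tensor was serialized with, e.g. 'cuda:3' or 'cpu'.
    if str(loc).startswith('cuda'):
        # Remap any saved CUDA storage onto this process's current CUDA device.
        return storage.cuda(torch.cuda.current_device())
    else:
        # Leave CPU storages on the CPU rather than forcing them onto a GPU.
        return storage.cpu()


# Hypothetical usage mirroring the patched call sites:
# resume_state = torch.load('checkpoints/latest.state', map_location=map_cuda_to_correct_device)

Unlike the removed lambdas, which unconditionally called storage.cuda(...), this keeps CPU-tagged storages on the CPU, so checkpoints containing CPU tensors can be restored without being forced onto a GPU.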