Revise device mapping
parent 33511243d5
commit 49edffb6ad
@@ -17,7 +17,7 @@ from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
 from datetime import datetime
 
-from utils.util import opt_get
+from utils.util import opt_get, map_cuda_to_correct_device
 
 
 def init_dist(backend, **kwargs):
@@ -43,9 +43,7 @@ class Trainer:
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
-            device_id = torch.cuda.current_device()
-            resume_state = torch.load(opt['path']['resume_state'],
-                                      map_location=lambda storage, loc: storage.cuda(device_id))
+            resume_state = torch.load(opt['path']['resume_state'], map_location=map_cuda_to_correct_device)
         else:
             resume_state = None
 
@@ -106,8 +106,7 @@ class BaseModel():
         # Sometimes networks are passed in as DDP modules, we want the raw parameters.
         if hasattr(network, 'module'):
             network = network.module
-        load_net = torch.load(load_path,
-                              map_location=lambda storage, loc: storage.cuda(self.rank if self.rank != -1 else 0))
+        load_net = torch.load(load_path, map_location=utils.util.map_cuda_to_correct_device)
 
         # Support loading torch.save()s for whole models as well as just state_dicts.
         if 'state_dict' in load_net:
@@ -488,3 +488,11 @@ def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load
     print(f"Loading from {load_path}")
     model.load_state_dict(torch.load(load_path), strict=strict_load)
     return model
+
+
+# Mapper for torch.load() that maps cuda devices to the correct CUDA device, but leaves CPU devices alone.
+def map_cuda_to_correct_device(storage, loc):
+    if str(loc).startswith('cuda'):
+        return storage.cuda(torch.cuda.current_device())
+    else:
+        return storage.cpu()
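For reference, the helper added above is meant to be passed straight to torch.load() as its map_location callable, as the two call sites changed in this commit now do: CUDA storages are retargeted to whichever device the calling process currently owns, while CPU storages stay on the CPU. Below is a minimal standalone sketch of that usage; the checkpoint path and tensor names are illustrative only and are not part of this commit.

import torch

# Helper from this commit: send CUDA storages to the process's current
# device, keep CPU storages on the CPU.
def map_cuda_to_correct_device(storage, loc):
    if str(loc).startswith('cuda'):
        return storage.cuda(torch.cuda.current_device())
    else:
        return storage.cpu()

# Illustrative checkpoint: a CPU tensor plus, when a GPU is present, a CUDA tensor.
state = {'weight': torch.randn(4, 4)}
if torch.cuda.is_available():
    state['bias'] = torch.zeros(4, device='cuda')
torch.save(state, '/tmp/example_state.pth')

# Reload: 'weight' stays on the CPU, 'bias' lands on the current CUDA device.
loaded = torch.load('/tmp/example_state.pth', map_location=map_cuda_to_correct_device)
print({k: v.device for k, v in loaded.items()})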