Revise device mapping

This commit is contained in:
James Betker 2022-01-24 15:08:13 -07:00
parent 33511243d5
commit 49edffb6ad
3 changed files with 11 additions and 6 deletions

View File

@@ -17,7 +17,7 @@ from trainer.ExtensibleTrainer import ExtensibleTrainer
from time import time from time import time
from datetime import datetime from datetime import datetime
from utils.util import opt_get from utils.util import opt_get, map_cuda_to_correct_device
def init_dist(backend, **kwargs): def init_dist(backend, **kwargs):
@@ -43,9 +43,7 @@ class Trainer:
#### loading resume state if exists #### loading resume state if exists
if opt['path'].get('resume_state', None): if opt['path'].get('resume_state', None):
# distributed resuming: all load into default GPU # distributed resuming: all load into default GPU
device_id = torch.cuda.current_device() resume_state = torch.load(opt['path']['resume_state'], map_location=map_cuda_to_correct_device)
resume_state = torch.load(opt['path']['resume_state'],
map_location=lambda storage, loc: storage.cuda(device_id))
else: else:
resume_state = None resume_state = None

View File

@@ -106,8 +106,7 @@ class BaseModel():
# Sometimes networks are passed in as DDP modules, we want the raw parameters. # Sometimes networks are passed in as DDP modules, we want the raw parameters.
if hasattr(network, 'module'): if hasattr(network, 'module'):
network = network.module network = network.module
load_net = torch.load(load_path, load_net = torch.load(load_path, map_location=utils.util.map_cuda_to_correct_device)
map_location=lambda storage, loc: storage.cuda(self.rank if self.rank != -1 else 0))
# Support loading torch.save()s for whole models as well as just state_dicts. # Support loading torch.save()s for whole models as well as just state_dicts.
if 'state_dict' in load_net: if 'state_dict' in load_net:

View File

@@ -488,3 +488,11 @@ def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load
print(f"Loading from {load_path}") print(f"Loading from {load_path}")
model.load_state_dict(torch.load(load_path), strict=strict_load) model.load_state_dict(torch.load(load_path), strict=strict_load)
return model return model
def map_cuda_to_correct_device(storage, loc):
    """`map_location` callable for torch.load().

    Relocates storages tagged with a CUDA location onto the currently
    selected CUDA device; any other storage (e.g. CPU) is kept on the CPU.

    Args:
        storage: the deserialized storage object supplied by torch.load().
        loc: the location tag recorded at save time (e.g. 'cuda:1', 'cpu').

    Returns:
        The storage moved to the appropriate device.
    """
    is_cuda_location = str(loc).startswith('cuda')
    if not is_cuda_location:
        return storage.cpu()
    return storage.cuda(torch.cuda.current_device())