load model state dicts into the correct device

it's not clear to me that this will make a huge difference, but it's a good idea anyways
This commit is contained in:
James Betker 2022-01-24 14:40:09 -07:00
parent 3e16c509f6
commit 33511243d5

View File

@ -106,7 +106,8 @@ class BaseModel():
# Sometimes networks are passed in as DDP modules, we want the raw parameters. # Sometimes networks are passed in as DDP modules, we want the raw parameters.
if hasattr(network, 'module'): if hasattr(network, 'module'):
network = network.module network = network.module
load_net = torch.load(load_path) load_net = torch.load(load_path,
map_location=lambda storage, loc: storage.cuda(self.rank if self.rank != -1 else 0))
# Support loading torch.save()s for whole models as well as just state_dicts. # Support loading torch.save()s for whole models as well as just state_dicts.
if 'state_dict' in load_net: if 'state_dict' in load_net: