load model state dicts into the correct device
it's not clear to me that this will make a huge difference, but it's a good idea anyways
This commit is contained in:
parent
3e16c509f6
commit
33511243d5
|
@ -106,7 +106,8 @@ class BaseModel():
|
||||||
# Sometimes networks are passed in as DDP modules, we want the raw parameters.
|
# Sometimes networks are passed in as DDP modules, we want the raw parameters.
|
||||||
if hasattr(network, 'module'):
|
if hasattr(network, 'module'):
|
||||||
network = network.module
|
network = network.module
|
||||||
load_net = torch.load(load_path)
|
load_net = torch.load(load_path,
|
||||||
|
map_location=lambda storage, loc: storage.cuda(self.rank if self.rank != -1 else 0))
|
||||||
|
|
||||||
# Support loading torch.save()s for whole models as well as just state_dicts.
|
# Support loading torch.save()s for whole models as well as just state_dicts.
|
||||||
if 'state_dict' in load_net:
|
if 'state_dict' in load_net:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user