forked from mrq/DL-Art-School
load model state dicts into the correct device
it's not clear to me that this will make a huge difference, but it's a good idea anyway
This commit is contained in:
parent 3e16c509f6
commit 33511243d5
@@ -106,7 +106,8 @@ class BaseModel():
         # Sometimes networks are passed in as DDP modules, we want the raw parameters.
         if hasattr(network, 'module'):
             network = network.module
-        load_net = torch.load(load_path)
+        load_net = torch.load(load_path,
+                              map_location=lambda storage, loc: storage.cuda(self.rank if self.rank != -1 else 0))
 
         # Support loading torch.save()s for whole models as well as just state_dicts.
         if 'state_dict' in load_net:
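For context, a minimal standalone sketch of the same map_location pattern; the checkpoint path and rank value below are hypothetical placeholders, not taken from the repository:

import torch

# Hypothetical values for illustration; in BaseModel these come from the trainer.
load_path = "checkpoint.pth"
rank = 0  # distributed rank of this process; -1 means single-process training

# By default, torch.load restores each tensor onto the device it was saved from,
# which may be the wrong GPU (or unavailable) on the loading machine.
# map_location is called with every storage and its original location string and
# returns the storage moved to the desired device; here, the GPU owned by this
# process (GPU 0 when not running distributed).
load_net = torch.load(
    load_path,
    map_location=lambda storage, loc: storage.cuda(rank if rank != -1 else 0),
)

# Handle both torch.save()d whole checkpoints and bare state_dicts.
if isinstance(load_net, dict) and 'state_dict' in load_net:
    load_net = load_net['state_dict']

An equivalent, slightly simpler form passes a device directly, e.g. map_location=torch.device(f'cuda:{rank}'), which torch.load also accepts.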