From 33511243d51e41c973d626d5be4d5a87b3418b6f Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 24 Jan 2022 14:40:09 -0700 Subject: [PATCH] load model state dicts into the correct device it's not clear to me that this will make a huge difference, but it's a good idea anyways --- codes/trainer/base_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index b5563341..969f9ca5 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -106,7 +106,8 @@ class BaseModel(): # Sometimes networks are passed in as DDP modules, we want the raw parameters. if hasattr(network, 'module'): network = network.module - load_net = torch.load(load_path) + load_net = torch.load(load_path, + map_location=lambda storage, loc: storage.cuda(self.rank if self.rank != -1 else 0)) # Support loading torch.save()s for whole models as well as just state_dicts. if 'state_dict' in load_net: