forked from mrq/DL-Art-School
Fix ddp bug
This commit is contained in:
parent
3e3ad7825f
commit
1cd75dfd33
|
@ -98,8 +98,9 @@ class BaseModel():
|
||||||
return save_path
|
return save_path
|
||||||
|
|
||||||
def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
|
def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
|
||||||
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
# Sometimes networks are passed in as DDP modules, we want the raw parameters.
|
||||||
network = network.module
|
if hasattr(network, 'module'):
|
||||||
|
network = network.module
|
||||||
load_net = torch.load(load_path)
|
load_net = torch.load(load_path)
|
||||||
|
|
||||||
# Support loading torch.save()s for whole models as well as just state_dicts.
|
# Support loading torch.save()s for whole models as well as just state_dicts.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user