forked from mrq/DL-Art-School
Fix ddp bug
This commit is contained in:
parent
3e3ad7825f
commit
1cd75dfd33
|
@ -98,8 +98,9 @@ class BaseModel():
|
|||
return save_path
|
||||
|
||||
def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
|
||||
#if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||
network = network.module
|
||||
# Sometimes networks are passed in as DDP modules, we want the raw parameters.
|
||||
if hasattr(network, 'module'):
|
||||
network = network.module
|
||||
load_net = torch.load(load_path)
|
||||
|
||||
# Support loading torch.save()s for whole models as well as just state_dicts.
|
||||
|
|
Loading…
Reference in New Issue
Block a user