forked from mrq/DL-Art-School
Allow dist_backend to be specified in options
This commit is contained in:
parent
14a868e8e6
commit
48f0d8964b
|
@@ -108,11 +108,13 @@ class ExtensibleTrainer(BaseModel):
|
||||||
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||||
for anet in all_networks:
|
for anet in all_networks:
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
|
if opt['dist_backend'] == 'apex':
|
||||||
# Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
|
# Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
|
||||||
from apex.parallel import DistributedDataParallel
|
from apex.parallel import DistributedDataParallel
|
||||||
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
||||||
#from torch.nn.parallel.distributed import DistributedDataParallel
|
else:
|
||||||
#dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
|
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
|
||||||
else:
|
else:
|
||||||
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
|
dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user