Allow dist_backend to be specified in options

This commit is contained in:
James Betker 2021-01-09 20:54:32 -07:00
parent 14a868e8e6
commit 48f0d8964b

View File

@ -108,11 +108,13 @@ class ExtensibleTrainer(BaseModel):
all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()] all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
for anet in all_networks: for anet in all_networks:
if opt['dist']: if opt['dist']:
if opt['dist_backend'] == 'apex':
# Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing. # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
from apex.parallel import DistributedDataParallel from apex.parallel import DistributedDataParallel
dnet = DistributedDataParallel(anet, delay_allreduce=True) dnet = DistributedDataParallel(anet, delay_allreduce=True)
#from torch.nn.parallel.distributed import DistributedDataParallel else:
#dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True) from torch.nn.parallel.distributed import DistributedDataParallel
dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
else: else:
dnet = DataParallel(anet, device_ids=opt['gpu_ids']) dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
if self.is_train: if self.is_train: