From 48f0d8964b04dcc23e06c7c99de8a470f912980c Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 9 Jan 2021 20:54:32 -0700
Subject: [PATCH] Allow dist_backend to be specified in options

---
 codes/trainer/ExtensibleTrainer.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index fb8ac6a1..3fffe510 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -108,11 +108,13 @@ class ExtensibleTrainer(BaseModel):
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
             if opt['dist']:
-                # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
-                from apex.parallel import DistributedDataParallel
-                dnet = DistributedDataParallel(anet, delay_allreduce=True)
-                #from torch.nn.parallel.distributed import DistributedDataParallel
-                #dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
+                if opt['dist_backend'] == 'apex':
+                    # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
+                    from apex.parallel import DistributedDataParallel
+                    dnet = DistributedDataParallel(anet, delay_allreduce=True)
+                else:
+                    from torch.nn.parallel.distributed import DistributedDataParallel
+                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
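
A minimal sketch of the option keys the patched branch reads, written as a plain Python dict. Only 'dist', 'dist_backend', and 'gpu_ids' come from the patch itself; the surrounding values and comments are illustrative assumptions about how an options object might be populated, not the project's actual option schema:

    # Hypothetical excerpt of a training options dict ('opt' in ExtensibleTrainer).
    opt = {
        'dist': True,            # distributed training is enabled
        'dist_backend': 'apex',  # wrap networks with apex.parallel.DistributedDataParallel (delay_allreduce=True)
        # Any other value falls through to torch.nn.parallel.distributed.DistributedDataParallel.
        'gpu_ids': [0],          # consulted only by the non-distributed DataParallel branch
    }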