diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index fb8ac6a1..3fffe510 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -108,11 +108,13 @@ class ExtensibleTrainer(BaseModel):
         all_networks = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
         for anet in all_networks:
             if opt['dist']:
-                # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
-                from apex.parallel import DistributedDataParallel
-                dnet = DistributedDataParallel(anet, delay_allreduce=True)
-                #from torch.nn.parallel.distributed import DistributedDataParallel
-                #dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
+                if opt['dist_backend'] == 'apex':
+                    # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
+                    from apex.parallel import DistributedDataParallel
+                    dnet = DistributedDataParallel(anet, delay_allreduce=True)
+                else:
+                    from torch.nn.parallel.distributed import DistributedDataParallel
+                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()])
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
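
For context, the new `opt['dist_backend']` key is read from the same options dict that already carries `dist` and `gpu_ids`. Below is a minimal sketch of just those keys, assuming the plain dict-style `opt` implied by the subscripting in the diff; every other option is omitted and the exact config format is not shown here.

```python
# Hypothetical excerpt of the options dict passed to ExtensibleTrainer.
# Only the keys touched by this change are shown.
opt = {
    'dist': True,            # enable the distributed-training branch
    'dist_backend': 'apex',  # 'apex' -> apex.parallel DDP with delay_allreduce
                             # any other value -> torch.nn.parallel.DistributedDataParallel
    'gpu_ids': [0],          # used only by the non-distributed DataParallel path
}
```

With `dist_backend` set to anything other than `'apex'`, the wrapper falls through to the stock torch `DistributedDataParallel` path added in this diff.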