From d1c63ae33906c2071eadd530fc96f3c7223cba38 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 16 Oct 2020 20:47:35 -0600
Subject: [PATCH] Go back to torch's DDP

Apex was having some weird crashing issues.
---
 codes/models/ExtensibleTrainer.py | 6 ++++--
 codes/models/base_model.py        | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 314f0b07..b12eb7fe 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -3,9 +3,9 @@ import os
 
 import torch
 from apex import amp
-from apex.parallel import DistributedDataParallel
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
+from torch.nn.parallel.distributed import DistributedDataParallel
 
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -109,7 +109,9 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet, delay_allreduce=True)
+                dnet = DistributedDataParallel(anet,
+                                               device_ids=[torch.cuda.current_device()],
+                                               find_unused_parameters=False)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:
diff --git a/codes/models/base_model.py b/codes/models/base_model.py
index 2672bfa5..04b6d9e3 100644
--- a/codes/models/base_model.py
+++ b/codes/models/base_model.py
@@ -2,7 +2,8 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from apex.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+
 import utils.util
 from apex import amp
 
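
For context, below is a minimal sketch of how the torch-native DistributedDataParallel wrapping used in this patch is typically set up end to end. It is not part of the patch: the wrap_for_ddp helper name, the process-group initialization, and the launcher assumptions (environment variables set by something like torch.distributed.launch) are illustrative assumptions, not taken from this repository.

# Illustrative sketch only (not part of this patch): wrapping a model with
# torch's native DDP using the same arguments as the patched ExtensibleTrainer.
# The helper name and the env:// launcher assumptions are hypothetical.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel.distributed import DistributedDataParallel


def wrap_for_ddp(model: nn.Module, local_rank: int) -> nn.Module:
    # One process per GPU: bind this process to its device, move the model there,
    # then let torch's DDP handle gradient all-reduce (replacing the apex
    # delay_allreduce behavior removed by this patch).
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    return DistributedDataParallel(model,
                                   device_ids=[torch.cuda.current_device()],
                                   find_unused_parameters=False)


if __name__ == '__main__':
    # Assumes a launcher (e.g. python -m torch.distributed.launch) has set
    # RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT in the environment.
    dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank() % torch.cuda.device_count()
    net = wrap_for_ddp(nn.Linear(16, 16), local_rank)

Passing find_unused_parameters=False, as the patch does, skips torch DDP's per-iteration scan for parameters that did not receive gradients, which keeps the all-reduce path cheap when every wrapped network uses all of its parameters each step.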