diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 314f0b07..b12eb7fe 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -3,9 +3,9 @@ import os
 
 import torch
 from apex import amp
-from apex.parallel import DistributedDataParallel
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
+from torch.nn.parallel.distributed import DistributedDataParallel
 
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -109,7 +109,9 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet, delay_allreduce=True)
+                dnet = DistributedDataParallel(anet,
+                                               device_ids=[torch.cuda.current_device()],
+                                               find_unused_parameters=False)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:
diff --git a/codes/models/base_model.py b/codes/models/base_model.py
index 2672bfa5..04b6d9e3 100644
--- a/codes/models/base_model.py
+++ b/codes/models/base_model.py
@@ -2,7 +2,8 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from apex.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+
 import utils.util
 from apex import amp
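
Note on the switch: apex's DistributedDataParallel wraps an already-CUDA model and, with delay_allreduce=True, defers all gradient reductions to the end of backward, whereas torch's native DistributedDataParallel requires the default process group to be initialized and assigns one device per process via device_ids. The sketch below is a minimal, hypothetical illustration of that setup, not code from this repository; the wrap_for_dist helper, the use of LOCAL_RANK, and the nccl backend are assumptions about a typical torchrun-style launch.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel


def wrap_for_dist(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: assumes the launcher (torchrun / torch.distributed.launch)
    # has set LOCAL_RANK, MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Native DDP needs the default process group before any model is wrapped.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    model = model.cuda()
    # Mirrors the diff: one GPU per process, and find_unused_parameters=False
    # skips the per-iteration search for parameters that receive no gradient.
    return DistributedDataParallel(model,
                                   device_ids=[torch.cuda.current_device()],
                                   find_unused_parameters=False)
```

find_unused_parameters=False matches the diff and is the cheaper default; it would only need to be True if some sub-networks are conditionally skipped during the forward pass. Unlike the delayed all-reduce in the apex wrapper, native DDP overlaps bucketed gradient all-reduces with backward by default.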