Go back to torch's DDP

Apex was having some weird crashing issues.
James Betker 2020-10-16 20:47:35 -06:00
parent d856378b2e
commit d1c63ae339
2 changed files with 6 additions and 3 deletions


@@ -3,9 +3,9 @@ import os
 import torch
 from apex import amp
-from apex.parallel import DistributedDataParallel
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
+from torch.nn.parallel.distributed import DistributedDataParallel
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -109,7 +109,9 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet, delay_allreduce=True)
+                dnet = DistributedDataParallel(anet,
+                                               device_ids=[torch.cuda.current_device()],
+                                               find_unused_parameters=False)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:
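
For context on the swapped arguments: apex's DistributedDataParallel took delay_allreduce=True, which defers all gradient communication to the end of the backward pass, while torch's native DDP buckets gradients and overlaps the all-reduce with backward; find_unused_parameters=False tells it to expect every registered parameter to receive a gradient. A minimal sketch of the new wrapping, assuming a torchrun-style launcher that sets LOCAL_RANK and a reachable process group (the toy Linear module is a stand-in for one of the trainer's networks):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel

# Assumes a distributed launcher (e.g. torchrun) exported LOCAL_RANK,
# RANK, and WORLD_SIZE before this runs.
dist.init_process_group(backend='nccl')
local_rank = int(os.environ.get('LOCAL_RANK', 0))
torch.cuda.set_device(local_rank)

net = torch.nn.Linear(16, 16).cuda()  # placeholder module
net = DistributedDataParallel(net,
                              device_ids=[torch.cuda.current_device()],
                              find_unused_parameters=False)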


@@ -2,7 +2,8 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from apex.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 import utils.util
 from apex import amp
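
Note that both files keep `from apex import amp`: only the parallel wrapper changed, and apex still handles mixed precision. Apex's documented ordering, calling amp.initialize before wrapping the model in DDP, still applies with torch's DDP. A hedged sketch of that ordering (the model, optimizer, and opt_level are placeholders, and a process group is assumed to be initialized as above):

import torch
from apex import amp
from torch.nn.parallel.distributed import DistributedDataParallel

# Placeholder model/optimizer; amp.initialize must run before DDP wrapping.
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
model = DistributedDataParallel(model,
                                device_ids=[torch.cuda.current_device()],
                                find_unused_parameters=False)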