Go back to torch's DDP

Apex was having some weird crashing issues.
James Betker 2020-10-16 20:47:35 -06:00
parent d856378b2e
commit d1c63ae339
2 changed files with 6 additions and 3 deletions

View File

@@ -3,9 +3,9 @@ import os
 import torch
 from apex import amp
-from apex.parallel import DistributedDataParallel
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
+from torch.nn.parallel.distributed import DistributedDataParallel
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
@@ -109,7 +109,9 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet, delay_allreduce=True)
+                dnet = DistributedDataParallel(anet,
+                                               device_ids=[torch.cuda.current_device()],
+                                               find_unused_parameters=False)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:

View File

@@ -2,7 +2,8 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from apex.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+
 import utils.util
 from apex import amp
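
For context, below is a minimal, runnable sketch of the torch-native DDP setup this commit switches to. Only the DistributedDataParallel call with device_ids and find_unused_parameters comes from the diff; the wrap_net helper, the process-group bootstrap, and the environment-variable names are illustrative assumptions, not code from this repository.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel


def wrap_net(anet: nn.Module, dist_enabled: bool) -> nn.Module:
    """Hypothetical helper mirroring the wrapping logic in the diff above."""
    if dist_enabled:
        # Same call as the new code in the diff: pin this replica to the
        # process's GPU and skip the unused-parameter graph scan. Note that
        # apex's delay_allreduce=True has no direct torch equivalent.
        return DistributedDataParallel(anet,
                                       device_ids=[torch.cuda.current_device()],
                                       find_unused_parameters=False)
    # Single-process fallback, as in the diff's else branch.
    return DataParallel(anet)


if __name__ == '__main__':
    # torchrun sets WORLD_SIZE and LOCAL_RANK; without them, run single-process.
    dist_enabled = int(os.environ.get('WORLD_SIZE', '1')) > 1
    if dist_enabled:
        dist.init_process_group(backend='nccl')
        torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = wrap_net(nn.Linear(8, 8).to(device), dist_enabled)
    print(net(torch.randn(4, 8, device=device)).shape)  # torch.Size([4, 8])

One design note: torch's DDP averages gradients during backward via bucketed all-reduce, so there is no switch equivalent to apex's delay_allreduce; setting find_unused_parameters=False simply avoids the per-iteration graph walk when every parameter receives a gradient.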