forked from mrq/DL-Art-School
Go back to torch's DDP
Apex was having some weird crashing issues.
This commit is contained in:
parent
d856378b2e
commit
d1c63ae339
|
@ -3,9 +3,9 @@ import os
|
|||
|
||||
import torch
|
||||
from apex import amp
|
||||
from apex.parallel import DistributedDataParallel
|
||||
from torch.nn.parallel import DataParallel
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
import models.lr_scheduler as lr_scheduler
|
||||
import models.networks as networks
|
||||
|
@ -109,7 +109,9 @@ class ExtensibleTrainer(BaseModel):
|
|||
dnets = []
|
||||
for anet in amp_nets:
|
||||
if opt['dist']:
|
||||
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
||||
dnet = DistributedDataParallel(anet,
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
find_unused_parameters=False)
|
||||
else:
|
||||
dnet = DataParallel(anet)
|
||||
if self.is_train:
|
||||
|
|
|
@ -2,7 +2,8 @@ import os
|
|||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from apex.parallel import DistributedDataParallel
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
import utils.util
|
||||
from apex import amp
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user