forked from mrq/DL-Art-School
Go back to torch's DDP
Apex was having some weird crashing issues.
parent d856378b2e
commit d1c63ae339
@@ -3,9 +3,9 @@ import os
 import torch
 from apex import amp
-from apex.parallel import DistributedDataParallel
 from torch.nn.parallel import DataParallel
 import torch.nn as nn
+from torch.nn.parallel.distributed import DistributedDataParallel
 
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
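Both modules export a class named DistributedDataParallel, so swapping the import leaves most call sites intact; the practical difference is that torch's wrapper must be constructed after the default process group exists. A minimal sketch of that setup, separate from this repo's code (the nccl backend and the placeholder module are assumptions):

import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel

# torchrun / torch.distributed.launch export RANK, WORLD_SIZE, MASTER_ADDR
# and MASTER_PORT, which the default env:// init method reads.
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

net = torch.nn.Linear(16, 16).cuda()  # placeholder module, not from this repo
dnet = DistributedDataParallel(net,
                               device_ids=[torch.cuda.current_device()],
                               find_unused_parameters=False)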
@@ -109,7 +109,9 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet, delay_allreduce=True)
+                dnet = DistributedDataParallel(anet,
+                                               device_ids=[torch.cuda.current_device()],
+                                               find_unused_parameters=False)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:
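A note on the argument swap: apex's delay_allreduce=True postpones all gradient communication until the backward pass has finished, whereas torch's DDP overlaps bucketed all-reduces with backward by default, and find_unused_parameters=False skips the per-iteration autograd-graph walk that the True setting needs to find parameters which received no gradient.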
@@ -2,7 +2,8 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from apex.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 
 import utils.util
 from apex import amp
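Both files keep `from apex import amp`, so mixed precision still flows through apex even with torch's DDP. Following the ordering used in apex's own examples, amp.initialize patches the model and optimizer before the DDP wrapper is applied. A rough sketch of that ordering (placeholder module and optimizer; assumes a process group is already initialized as above):

import torch
from apex import amp
from torch.nn.parallel.distributed import DistributedDataParallel

net = torch.nn.Linear(16, 16).cuda()      # placeholder module
opt = torch.optim.Adam(net.parameters())  # placeholder optimizer

# Patch with amp first, then wrap: DDP's reducer should register the
# parameters amp will actually be updating.
net, opt = amp.initialize(net, opt, opt_level='O1')
dnet = DistributedDataParallel(net,
                               device_ids=[torch.cuda.current_device()],
                               find_unused_parameters=False)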