forked from mrq/DL-Art-School
Allow swapping to torch DDP as needed in code
parent 66cbae8731
commit c203cee31e
@@ -108,6 +108,8 @@ class ExtensibleTrainer(BaseModel):
                 # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
                 from apex.parallel import DistributedDataParallel
                 dnet = DistributedDataParallel(anet, delay_allreduce=True)
+                #from torch.nn.parallel.distributed import DistributedDataParallel
+                #dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
             else:
                 dnet = DataParallel(anet, device_ids=opt['gpu_ids'])
             if self.is_train:
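The hunk keeps the Apex wrapper active and adds torch's native DistributedDataParallel as commented-out lines, so switching backends means commenting one pair and uncommenting the other. The sketch below expresses the same choice as a runtime flag instead; `wrap_model`, `use_apex_ddp`, and `gpu_ids` are illustrative names and not identifiers from DL-Art-School.

import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel as TorchDDP


def wrap_model(anet: torch.nn.Module, dist: bool, use_apex_ddp: bool, gpu_ids):
    """Sketch of choosing a parallel wrapper for one network (hypothetical helper)."""
    if dist:
        if use_apex_ddp:
            # Apex DDP: delay_allreduce defers gradient reduction until backward
            # completes, which is compatible with gradient checkpointing.
            from apex.parallel import DistributedDataParallel as ApexDDP
            return ApexDDP(anet, delay_allreduce=True)
        # Torch-native DDP: find_unused_parameters tolerates graphs where some
        # parameters receive no gradient in a given step.
        return TorchDDP(anet,
                        device_ids=[torch.cuda.current_device()],
                        find_unused_parameters=True)
    # Non-distributed fallback: plain DataParallel across the configured GPUs.
    return DataParallel(anet, device_ids=gpu_ids)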