Move to apex distributeddataparallel and add switch all_reduce
Torch's distributed_data_parallel is missing "delay_allreduce", which is necessary to get gradient checkpointing to work with recurrent models.
This commit is contained in:
parent
c174ac0fd5
commit
fba29d7dcc
|
@ -3,7 +3,8 @@ import os
|
|||
|
||||
import torch
|
||||
from apex import amp
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
from apex.parallel import DistributedDataParallel
|
||||
from torch.nn.parallel import DataParallel
|
||||
import torch.nn as nn
|
||||
|
||||
import models.lr_scheduler as lr_scheduler
|
||||
|
@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
dnets = []
|
||||
for anet in amp_nets:
|
||||
if opt['dist']:
|
||||
dnet = DistributedDataParallel(anet,
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
find_unused_parameters=False)
|
||||
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
||||
else:
|
||||
dnet = DataParallel(anet)
|
||||
if self.is_train:
|
||||
|
@ -313,4 +312,4 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
def force_restore_swapout(self):
|
||||
# Legacy method. Do nothing.
|
||||
pass
|
||||
pass
|
||||
|
|
|
@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):
|
|||
|
||||
def update_for_step(self, step, experiments_path='.'):
|
||||
if self.attentions:
|
||||
# All-reduce the attention norm.
|
||||
for sw in self.switches:
|
||||
sw.switch.reduce_norm_params()
|
||||
|
||||
temp = max(1, 1 + self.init_temperature *
|
||||
(self.final_temperature_step - step) / self.final_temperature_step)
|
||||
self.set_temperature(temp)
|
||||
|
@ -496,4 +500,4 @@ class StackedSwitchGenerator5Layer(nn.Module):
|
|||
for i in range(len(means)):
|
||||
val["switch_%i_specificity" % (i,)] = means[i]
|
||||
val["switch_%i_histogram" % (i,)] = hists[i]
|
||||
return val
|
||||
return val
|
||||
|
|
|
@ -2,7 +2,7 @@ import os
|
|||
from collections import OrderedDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from apex.parallel import DistributedDataParallel
|
||||
import utils.util
|
||||
from apex import amp
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user