Move to apex DistributedDataParallel and add switch all_reduce
Torch's DistributedDataParallel has no "delay_allreduce" option, which is necessary to get gradient checkpointing to work with recurrent models; apex's implementation provides it.
This commit is contained in:
parent
c174ac0fd5
commit
fba29d7dcc
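A minimal sketch of the pattern this change enables, assuming apex is installed and the default process group has already been initialized by the launcher; the toy GRU, tensor shapes, and variable names below are illustrative, not code from this repository:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from apex.parallel import DistributedDataParallel

# Assumes torch.distributed.init_process_group(...) has already been called
# (e.g. one process per GPU via a distributed launcher).
net = nn.GRU(input_size=64, hidden_size=64, batch_first=True).cuda()

# delay_allreduce=True defers the gradient all-reduce until backward() has
# finished, instead of hooking each parameter's gradient as it is produced.
# That is what lets the re-computed (checkpointed) forward passes of a
# recurrent model coexist with distributed training.
net = DistributedDataParallel(net, delay_allreduce=True)

x = torch.randn(8, 16, 64, device='cuda', requires_grad=True)
out, _ = checkpoint(net, x)   # forward is re-run inside backward()
out.mean().backward()         # gradients are all-reduced once, at the end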
@@ -3,7 +3,8 @@ import os
 import torch
 from apex import amp
-from torch.nn.parallel import DataParallel, DistributedDataParallel
+from apex.parallel import DistributedDataParallel
+from torch.nn.parallel import DataParallel
 import torch.nn as nn
 
 import models.lr_scheduler as lr_scheduler
@@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet,
-                                               device_ids=[torch.cuda.current_device()],
-                                               find_unused_parameters=False)
+                dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:
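The device_ids and find_unused_parameters arguments drop away here because apex's wrapper works on whatever device the wrapped module already lives on and, with delay_allreduce=True, all-reduces every gradient in one pass after backward() returns rather than reducing buckets as per-parameter hooks fire. A rough, hypothetical illustration of what that delayed reduction amounts to (not apex's actual implementation):

import torch.distributed as dist

def allreduce_grads_after_backward(module):
    # Hypothetical helper: average each gradient across ranks in a single
    # sweep once the whole backward pass has finished.
    world_size = dist.get_world_size()
    for p in module.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
            p.grad.data.div_(world_size)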
@@ -313,4 +312,4 @@ class ExtensibleTrainer(BaseModel):
 
     def force_restore_swapout(self):
         # Legacy method. Do nothing.
         pass
@@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):
 
     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
+            # All-reduce the attention norm.
+            for sw in self.switches:
+                sw.switch.reduce_norm_params()
+
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
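The body of reduce_norm_params() is not part of this diff; conceptually, an all-reduce of the attention norm keeps each switch's accumulated usage statistics identical across ranks, so the specificity/histogram reporting below reflects the whole distributed run rather than a single process. A hypothetical stand-in (the function and buffer names are assumptions, not this repository's code):

import torch
import torch.distributed as dist

def reduce_norm_params(norm_counts: torch.Tensor) -> torch.Tensor:
    # Hypothetical: sum the locally accumulated attention-usage counts across
    # all ranks so every process normalizes with the same statistics.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(norm_counts, op=dist.ReduceOp.SUM)
    return norm_counts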
@@ -496,4 +500,4 @@ class StackedSwitchGenerator5Layer(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val
@@ -2,7 +2,7 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 import utils.util
 from apex import amp
 