Move to apex DistributedDataParallel and add switch all_reduce

Torch's DistributedDataParallel is missing delay_allreduce, which is
necessary to get gradient checkpointing to work with recurrent models.
James Betker 2020-10-08 11:20:05 -06:00
parent c174ac0fd5
commit fba29d7dcc
3 changed files with 10 additions and 7 deletions
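
For context, a minimal sketch of the wrapping pattern this commit moves to, assuming one process per GPU and that torch.distributed is already initialized. build_model and the optimizer settings are placeholders, not code from this repository:

import torch
from apex import amp
from apex.parallel import DistributedDataParallel

model = build_model().cuda()  # hypothetical constructor, stands in for any network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# amp must initialize the model/optimizer pair before the model is wrapped.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# delay_allreduce=True defers the gradient all-reduce until backward() has
# finished, instead of reducing bucket-by-bucket as gradients become ready.
# That tolerates the irregular gradient production that gradient
# checkpointing causes in recurrent models.
model = DistributedDataParallel(model, delay_allreduce=True)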

View File

@@ -3,7 +3,8 @@ import os
 import torch
 from apex import amp
-from torch.nn.parallel import DataParallel, DistributedDataParallel
+from apex.parallel import DistributedDataParallel
+from torch.nn.parallel import DataParallel
 import torch.nn as nn
 import models.lr_scheduler as lr_scheduler
@@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet,
-                                               device_ids=[torch.cuda.current_device()],
-                                               find_unused_parameters=False)
+                dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:
@@ -313,4 +312,4 @@ class ExtensibleTrainer(BaseModel):
     def force_restore_swapout(self):
         # Legacy method. Do nothing.
         pass

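Condensed, a self-contained sketch of the wrapping logic this file ends up with. The helper name is illustrative, and opt and amp_nets mirror names from the surrounding trainer; their contents are assumptions:

from apex.parallel import DistributedDataParallel
from torch.nn.parallel import DataParallel

def wrap_networks(amp_nets, opt):
    # After amp.initialize(), each network is wrapped for multi-GPU execution.
    # The distributed path now uses apex DDP so delay_allreduce is available;
    # the single-node path keeps torch's DataParallel.
    dnets = []
    for anet in amp_nets:
        if opt['dist']:
            dnet = DistributedDataParallel(anet, delay_allreduce=True)
        else:
            dnet = DataParallel(anet)
        dnets.append(dnet)
    return dnets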
View File

@@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):
     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
+            # All-reduce the attention norm.
+            for sw in self.switches:
+                sw.switch.reduce_norm_params()
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
@@ -496,4 +500,4 @@ class StackedSwitchGenerator5Layer(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
         return val

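The diff only shows the call site for reduce_norm_params. As a hypothetical illustration of what such a helper can do, the sketch below averages a locally accumulated attention-norm tensor across ranks with torch.distributed.all_reduce; the buffer name and the averaging scheme are assumptions, not the repository's implementation:

import torch
import torch.distributed as dist

def reduce_norm_params(norm_buffer: torch.Tensor) -> torch.Tensor:
    # Sum the per-rank statistics, then divide by world size so every process
    # sees the same averaged attention norm before updating switch statistics.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(norm_buffer, op=dist.ReduceOp.SUM)
        norm_buffer /= dist.get_world_size()
    return norm_buffer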
View File

@@ -2,7 +2,7 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 import utils.util
 from apex import amp