Move to apex DistributedDataParallel and add switch all_reduce

Torch's DistributedDataParallel is missing the "delay_allreduce" option, which is
necessary to get gradient checkpointing to work with recurrent models.
Author: James Betker
Date:   2020-10-08 11:20:05 -06:00
Parent: c174ac0fd5
Commit: fba29d7dcc
3 changed files with 10 additions and 7 deletions
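
For context, a minimal illustrative sketch of the two wrapping styles being swapped here. The network, the flag, and the CUDA/process-group setup are stand-in assumptions (apex must be installed separately and torch.distributed must already be initialized for either wrapper to run); this is not the trainer's actual code.

import torch
import torch.nn as nn
from apex.parallel import DistributedDataParallel as ApexDDP
from torch.nn.parallel import DistributedDataParallel as TorchDDP

# Stand-in network and flag; the real trainer decides this via its options.
net = nn.Conv2d(3, 3, 3).cuda()
use_apex = True

if use_apex:
    # apex's wrapper: delay_allreduce=True defers the gradient all-reduce until
    # the whole backward pass has finished, which is what the commit message
    # says checkpointed recurrent models need.
    dnet = ApexDDP(net, delay_allreduce=True)
else:
    # torch's wrapper: gradients are all-reduced bucket-by-bucket as soon as
    # they become ready during backward().
    dnet = TorchDDP(net, device_ids=[torch.cuda.current_device()])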

@@ -3,7 +3,8 @@ import os
 import torch
 from apex import amp
-from torch.nn.parallel import DataParallel, DistributedDataParallel
+from apex.parallel import DistributedDataParallel
+from torch.nn.parallel import DataParallel
 import torch.nn as nn
 import models.lr_scheduler as lr_scheduler
@@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet,
-                                               device_ids=[torch.cuda.current_device()],
-                                               find_unused_parameters=False)
+                dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:
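
A rough usage sketch of the initialization order this hunk sits in: register networks and optimizers with apex amp first, then apply the data-parallel wrapper, as apex's documentation recommends. The stand-in network, optimizer, and opt dict below are assumptions for illustration (a CUDA device and an initialized process group are also assumed); only the opt['dist'] branch and delay_allreduce mirror the diff.

import torch
import torch.nn as nn
from apex import amp
from apex.parallel import DistributedDataParallel
from torch.nn.parallel import DataParallel

# Illustrative stand-ins; the real trainer builds these from its options dict.
nets = [nn.Conv2d(3, 64, 3).cuda()]
optimizers = [torch.optim.Adam(nets[0].parameters())]
opt = {'dist': True}

# Register models/optimizers with amp before any data-parallel wrapping.
amp_nets, amp_opts = amp.initialize(nets, optimizers, opt_level='O1')

# Then wrap each network, mirroring the hunk above.
dnets = []
for anet in amp_nets:
    if opt['dist']:
        dnet = DistributedDataParallel(anet, delay_allreduce=True)
    else:
        dnet = DataParallel(anet)
    dnets.append(dnet)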

@@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):
     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
+            # All-reduce the attention norm.
+            for sw in self.switches:
+                sw.switch.reduce_norm_params()
+
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
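
reduce_norm_params() is defined on the switch modules elsewhere in the repository; as a hedged illustration of the general pattern such a call relies on (the function name and averaging strategy below are assumptions, not the repo's implementation), a cross-rank all-reduce of a statistics tensor looks like this. The surrounding context lines also show the temperature schedule, which anneals linearly from 1 + init_temperature at step 0 down to 1 at final_temperature_step.

import torch
import torch.distributed as dist

def all_reduce_stat(buf: torch.Tensor) -> None:
    # Average a per-rank statistics buffer across all ranks so every process
    # sees the same attention norm. No-op outside a distributed run.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(buf, op=dist.ReduceOp.SUM)
        buf /= dist.get_world_size()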

@@ -2,7 +2,7 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 import utils.util
 from apex import amp
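
This import plausibly exists so that base-model code can recognize wrapped networks when saving or describing them; a hedged sketch of that pattern follows. The helper name is hypothetical, and the only assumed fact is that both nn.DataParallel and apex's DistributedDataParallel expose the underlying network as .module.

import torch.nn as nn
from apex.parallel import DistributedDataParallel

def unwrap_network(net: nn.Module) -> nn.Module:
    # Hypothetical helper: strip a (Distributed)DataParallel wrapper so that
    # state_dict keys and printed descriptions refer to the bare network.
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        return net.module
    return net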