diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 4f2b7636..079da6e1 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -3,7 +3,8 @@ import os
 
 import torch
 from apex import amp
-from torch.nn.parallel import DataParallel, DistributedDataParallel
+from apex.parallel import DistributedDataParallel
+from torch.nn.parallel import DataParallel
 import torch.nn as nn
 
 import models.lr_scheduler as lr_scheduler
@@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
         dnets = []
         for anet in amp_nets:
             if opt['dist']:
-                dnet = DistributedDataParallel(anet,
-                                               device_ids=[torch.cuda.current_device()],
-                                               find_unused_parameters=False)
+                dnet = DistributedDataParallel(anet, delay_allreduce=True)
             else:
                 dnet = DataParallel(anet)
             if self.is_train:
@@ -313,4 +312,4 @@ class ExtensibleTrainer(BaseModel):
 
     def force_restore_swapout(self):
         # Legacy method. Do nothing.
-        pass
\ No newline at end of file
+        pass
diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py
index b5b37500..8c74b313 100644
--- a/codes/models/archs/StructuredSwitchedGenerator.py
+++ b/codes/models/archs/StructuredSwitchedGenerator.py
@@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):
 
     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
+            # All-reduce the attention norm.
+            for sw in self.switches:
+                sw.switch.reduce_norm_params()
+
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
@@ -496,4 +500,4 @@ class StackedSwitchGenerator5Layer(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-        return val
\ No newline at end of file
+        return val
diff --git a/codes/models/base_model.py b/codes/models/base_model.py
index 0b092520..2672bfa5 100644
--- a/codes/models/base_model.py
+++ b/codes/models/base_model.py
@@ -2,7 +2,7 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 import utils.util
 from apex import amp
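
A minimal sketch of the wrapping pattern this diff switches to: apex's DistributedDataParallel takes the amp-initialized module directly (no device_ids list) and, with delay_allreduce=True, performs gradient all-reduce only after the full backward pass, which tolerates parameters that receive no gradient on a given step. The network, optimizer, and launch setup below are hypothetical stand-ins, assuming apex is installed, one GPU per process, and a launcher (e.g. torch.distributed.launch) that provides the usual distributed environment variables.

    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from apex import amp
    from apex.parallel import DistributedDataParallel

    # One process per GPU; LOCAL_RANK is assumed to be exported by the launcher.
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    # Hypothetical network and optimizer standing in for the trainer's nets.
    net = nn.Linear(64, 64).cuda()
    opt = torch.optim.Adam(net.parameters(), lr=1e-4)

    # amp wraps the raw module and optimizer before parallelization.
    net, opt = amp.initialize(net, opt, opt_level='O1')

    # apex DDP uses the current CUDA device; delay_allreduce=True defers all
    # gradient communication until backward() completes.
    net = DistributedDataParallel(net, delay_allreduce=True)

    x = torch.randn(8, 64).cuda()
    loss = net(x).mean()
    with amp.scale_loss(loss, opt) as scaled_loss:
        scaled_loss.backward()
    opt.step()

Compared with torch.nn.parallel.DistributedDataParallel, this trades overlapped communication for robustness: there is no find_unused_parameters flag to tune, which suits the gated switch paths touched in this change.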