diff --git a/codes/trainer/batch_size_optimizer.py b/codes/trainer/batch_size_optimizer.py
index 1271ba40..8730dc4e 100644
--- a/codes/trainer/batch_size_optimizer.py
+++ b/codes/trainer/batch_size_optimizer.py
@@ -2,6 +2,7 @@
 import math
 import random
 
 import torch
+from torch import distributed
 
 from utils.util import opt_get
@@ -81,10 +82,16 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
         if model._gradient_direction_optimizer_prior_grads is not None:
             cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
             delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
-            delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean().item()
+            delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean()
             model._gradient_direction_optimizer_prior_directions = cur_dir
             model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
-            if delta_delta_dir < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
+
+            # For distributed optimizers, like ZeroRedundancyAdam, we need to reach a consensus as to whether or not to reduce.
+            if distributed.is_initialized() and distributed.get_world_size() > 1:
+                distributed.all_reduce(delta_delta_dir)
+                delta_delta_dir = delta_delta_dir / distributed.get_world_size()
+
+            if delta_delta_dir.item() < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
                 # <0 means the gradient direction is getting larger. Halt batch accumulation here.
                 model._gradient_direction_optimizer_finished = True
                 self.record_number_steps(model._gradient_direction_optimizer_step)
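
A minimal, standalone sketch of the consensus step this patch adds (not part of the diff): each rank computes its own local delta_delta_dir, and averaging it with all_reduce ensures every rank evaluates the same halt condition, so sharded optimizers such as ZeroRedundancyOptimizer never disagree about when to stop accumulating. The gloo backend, the loopback rendezvous, the two-process spawn, and the per-rank values below are illustrative assumptions, not taken from the repository.

import os

import torch
import torch.multiprocessing as mp
from torch import distributed


def consensus_demo(rank, world_size):
    # Illustrative rendezvous over loopback; any init method would do.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    distributed.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend each rank measured a different gradient-direction delta (made-up values).
    delta_delta_dir = torch.tensor(0.5 if rank == 0 else -1.0)

    # Same pattern as the patch: sum across ranks, then divide by world size.
    if distributed.is_initialized() and distributed.get_world_size() > 1:
        distributed.all_reduce(delta_delta_dir)  # default reduce op is SUM
        delta_delta_dir = delta_delta_dir / distributed.get_world_size()

    # Every rank now sees the same averaged value and makes the same decision.
    halt = delta_delta_dir.item() < 0
    print(f"rank {rank}: delta_delta_dir={delta_delta_dir.item():.3f}, halt={halt}")
    distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(consensus_demo, args=(2,), nprocs=2)

Dividing the all_reduce sum by the world size keeps the < 0 threshold meaningful as a mean across ranks, mirroring what the patch does before calling .item().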