From 23a310b48874b3123fdfd5bf7788112ececfda01 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 10 Feb 2022 20:54:51 -0700
Subject: [PATCH] Fix BSO

---
 codes/trainer/batch_size_optimizer.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/codes/trainer/batch_size_optimizer.py b/codes/trainer/batch_size_optimizer.py
index 8730dc4e..354311ca 100644
--- a/codes/trainer/batch_size_optimizer.py
+++ b/codes/trainer/batch_size_optimizer.py
@@ -3,6 +3,7 @@ import random
 
 import torch
 from torch import distributed
+from torch._C._distributed_c10d import ReduceOp
 
 from utils.util import opt_get
 
@@ -69,7 +70,7 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
         num_params = min(len(all_params), self.parameters_to_poll)
         model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
         model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
-        model._gradient_direction_optimizer_direction_change_magnitudes = [math.pi for _ in range(num_params)]
+        model._gradient_direction_optimizer_stopped_decreasing = [False for _ in range(num_params)]
         model._gradient_direction_optimizer_prior_grads = None
         model._gradient_direction_optimizer_step = 0
         model._gradient_direction_optimizer_finished = False
@@ -82,16 +83,17 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
         if model._gradient_direction_optimizer_prior_grads is not None:
             cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
             delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
-            delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean()
             model._gradient_direction_optimizer_prior_directions = cur_dir
-            model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
+            model._gradient_direction_optimizer_stopped_decreasing = [sd or dd < 0 for sd, dd in zip(model._gradient_direction_optimizer_stopped_decreasing, delta_dir)]
+            all_finished = all(model._gradient_direction_optimizer_stopped_decreasing)
 
             # For distributed optimizers, like ZeroRedundancyAdam, we need to reach a consensus as to whether or not to reduce.
             if distributed.is_initialized() and distributed.get_world_size() > 1:
-                distributed.all_reduce(delta_delta_dir)
-                delta_delta_dir = delta_delta_dir / distributed.get_world_size()
+                all_finished = torch.tensor(all_finished)
+                distributed.all_reduce(all_finished, ReduceOp.BAND)
+                all_finished = torch.all(all_finished)
 
-            if delta_delta_dir.item() < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
+            if all_finished or model._gradient_direction_optimizer_step >= self.max_full_batches:
                 # <0 means the gradient direction is getting larger. Halt batch accumulation here.
                 model._gradient_direction_optimizer_finished = True
                 self.record_number_steps(model._gradient_direction_optimizer_step)
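
Note: the distributed branch above reaches consensus by AND-reducing each rank's "finished" flag, so batch accumulation halts only once every rank has stopped seeing its gradient-direction deltas decrease. A minimal standalone sketch of that consensus pattern follows; the helper name all_ranks_finished and the use of ReduceOp.MIN over a 0/1 integer flag (a logical AND supported by both the gloo and nccl backends) are illustrative assumptions, not code from this patch, which reduces a bool tensor with ReduceOp.BAND.

import torch
import torch.distributed as distributed

def all_ranks_finished(local_finished: bool) -> bool:
    # Hypothetical helper: returns True only if every rank reports finished.
    # Single-process (or uninitialized) runs just fall back to the local flag.
    if not (distributed.is_initialized() and distributed.get_world_size() > 1):
        return local_finished
    # nccl requires CUDA tensors for collectives; gloo works on CPU.
    device = "cuda" if distributed.get_backend() == "nccl" else "cpu"
    flag = torch.tensor(int(local_finished), dtype=torch.int32, device=device)
    # MIN over 0/1 flags acts as a logical AND across ranks.
    distributed.all_reduce(flag, op=distributed.ReduceOp.MIN)
    return bool(flag.item())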