BSO improvement to make it work with distributed optimizers

This commit is contained in:
James Betker 2022-02-10 09:53:13 -07:00
parent 836eb08afb
commit 1e28e02f98

View File

@ -2,6 +2,7 @@ import math
import random
import torch
from torch import distributed
from utils.util import opt_get
@ -81,10 +82,16 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
if model._gradient_direction_optimizer_prior_grads is not None:
cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean().item()
delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean()
model._gradient_direction_optimizer_prior_directions = cur_dir
model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
if delta_delta_dir < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
# For distributed optimizers, like ZeroRedundancyAdam, we need to reach a consensus as to whether or not to reduce.
if distributed.is_initialized() and distributed.get_world_size() > 1:
distributed.all_reduce(delta_delta_dir)
delta_delta_dir = delta_delta_dir / distributed.get_world_size()
if delta_delta_dir.item() < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
model._gradient_direction_optimizer_finished = True
self.record_number_steps(model._gradient_direction_optimizer_step)