This commit is contained in:
James Betker 2022-02-10 20:54:51 -07:00
parent 1e28e02f98
commit 23a310b488

View File

@ -3,6 +3,7 @@ import random
import torch
from torch import distributed
from torch._C._distributed_c10d import ReduceOp
from utils.util import opt_get
@ -69,7 +70,7 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
num_params = min(len(all_params), self.parameters_to_poll)
model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
model._gradient_direction_optimizer_direction_change_magnitudes = [math.pi for _ in range(num_params)]
model._gradient_direction_optimizer_stopped_decreasing = [False for _ in range(num_params)]
model._gradient_direction_optimizer_prior_grads = None
model._gradient_direction_optimizer_step = 0
model._gradient_direction_optimizer_finished = False
@ -82,16 +83,17 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
if model._gradient_direction_optimizer_prior_grads is not None:
cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean()
model._gradient_direction_optimizer_prior_directions = cur_dir
model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
model._gradient_direction_optimizer_stopped_decreasing = [sd or dd < 0 for sd, dd in zip(model._gradient_direction_optimizer_stopped_decreasing, delta_dir)]
all_finished = all(model._gradient_direction_optimizer_stopped_decreasing)
# For distributed optimizers, like ZeroRedundancyAdam, we need to reach a consensus as to whether or not to reduce.
if distributed.is_initialized() and distributed.get_world_size() > 1:
distributed.all_reduce(delta_delta_dir)
delta_delta_dir = delta_delta_dir / distributed.get_world_size()
all_finished = torch.tensor(all_finished)
distributed.all_reduce(all_finished, ReduceOp.BAND)
all_finished = torch.all(all_finished)
if delta_delta_dir.item() < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
if all_finished or model._gradient_direction_optimizer_step >= self.max_full_batches:
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
model._gradient_direction_optimizer_finished = True
self.record_number_steps(model._gradient_direction_optimizer_step)