BSO improvement to make it work with distributed optimizers
This commit is contained in:
parent
836eb08afb
commit
1e28e02f98
|
@ -2,6 +2,7 @@ import math
|
|||
import random
|
||||
|
||||
import torch
|
||||
from torch import distributed
|
||||
|
||||
from utils.util import opt_get
|
||||
|
||||
|
@ -81,10 +82,16 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
|
|||
if model._gradient_direction_optimizer_prior_grads is not None:
|
||||
cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
|
||||
delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
|
||||
delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean().item()
|
||||
delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean()
|
||||
model._gradient_direction_optimizer_prior_directions = cur_dir
|
||||
model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
|
||||
if delta_delta_dir < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
|
||||
|
||||
# For distributed optimizers, like ZeroRedundancyAdam, we need to reach a consensus as to whether or not to reduce.
|
||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||
distributed.all_reduce(delta_delta_dir)
|
||||
delta_delta_dir = delta_delta_dir / distributed.get_world_size()
|
||||
|
||||
if delta_delta_dir.item() < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
|
||||
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
|
||||
model._gradient_direction_optimizer_finished = True
|
||||
self.record_number_steps(model._gradient_direction_optimizer_step)
|
||||
|
|
Loading…
Reference in New Issue
Block a user