forked from mrq/DL-Art-School
BSO improvement to make it work with distributed optimizers
This commit is contained in:
parent
836eb08afb
commit
1e28e02f98
|
@ -2,6 +2,7 @@ import math
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import distributed
|
||||||
|
|
||||||
from utils.util import opt_get
|
from utils.util import opt_get
|
||||||
|
|
||||||
|
@ -81,10 +82,16 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
|
||||||
if model._gradient_direction_optimizer_prior_grads is not None:
|
if model._gradient_direction_optimizer_prior_grads is not None:
|
||||||
cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
|
cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
|
||||||
delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
|
delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
|
||||||
delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean().item()
|
delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean()
|
||||||
model._gradient_direction_optimizer_prior_directions = cur_dir
|
model._gradient_direction_optimizer_prior_directions = cur_dir
|
||||||
model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
|
model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
|
||||||
if delta_delta_dir < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
|
|
||||||
|
# For distributed optimizers, like ZeroRedundancyAdam, we need to reach a consensus as to whether or not to reduce.
|
||||||
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
|
distributed.all_reduce(delta_delta_dir)
|
||||||
|
delta_delta_dir = delta_delta_dir / distributed.get_world_size()
|
||||||
|
|
||||||
|
if delta_delta_dir.item() < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
|
||||||
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
|
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
|
||||||
model._gradient_direction_optimizer_finished = True
|
model._gradient_direction_optimizer_finished = True
|
||||||
self.record_number_steps(model._gradient_direction_optimizer_step)
|
self.record_number_steps(model._gradient_direction_optimizer_step)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user