forked from mrq/DL-Art-School
Fix BSO
This commit is contained in:
parent
1e28e02f98
commit
23a310b488
|
@ -3,6 +3,7 @@ import random
|
|||
|
||||
import torch
|
||||
from torch import distributed
|
||||
from torch._C._distributed_c10d import ReduceOp
|
||||
|
||||
from utils.util import opt_get
|
||||
|
||||
|
@ -69,7 +70,7 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
|
|||
num_params = min(len(all_params), self.parameters_to_poll)
|
||||
model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
|
||||
model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
|
||||
model._gradient_direction_optimizer_direction_change_magnitudes = [math.pi for _ in range(num_params)]
|
||||
model._gradient_direction_optimizer_stopped_decreasing = [False for _ in range(num_params)]
|
||||
model._gradient_direction_optimizer_prior_grads = None
|
||||
model._gradient_direction_optimizer_step = 0
|
||||
model._gradient_direction_optimizer_finished = False
|
||||
|
@ -82,16 +83,17 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
|
|||
if model._gradient_direction_optimizer_prior_grads is not None:
|
||||
cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
|
||||
delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
|
||||
delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean()
|
||||
model._gradient_direction_optimizer_prior_directions = cur_dir
|
||||
model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
|
||||
model._gradient_direction_optimizer_stopped_decreasing = [sd or dd < 0 for sd, dd in zip(model._gradient_direction_optimizer_stopped_decreasing, delta_dir)]
|
||||
all_finished = all(model._gradient_direction_optimizer_stopped_decreasing)
|
||||
|
||||
# For distributed optimizers, like ZeroRedundancyAdam, we need to reach a consensus as to whether or not to reduce.
|
||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||
distributed.all_reduce(delta_delta_dir)
|
||||
delta_delta_dir = delta_delta_dir / distributed.get_world_size()
|
||||
all_finished = torch.tensor(all_finished)
|
||||
distributed.all_reduce(all_finished, ReduceOp.BAND)
|
||||
all_finished = torch.all(all_finished)
|
||||
|
||||
if delta_delta_dir.item() < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
|
||||
if all_finished or model._gradient_direction_optimizer_step >= self.max_full_batches:
|
||||
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
|
||||
model._gradient_direction_optimizer_finished = True
|
||||
self.record_number_steps(model._gradient_direction_optimizer_step)
|
||||
|
|
Loading…
Reference in New Issue
Block a user