From 3d946356f8c821969af02a63b669b3468ed337de Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 9 Feb 2022 14:26:23 -0700
Subject: [PATCH] batch_size_optimizer works. sweet! no more tuning batch sizes.

---
 codes/train.py                        |  2 +-
 codes/trainer/ExtensibleTrainer.py    |  2 +-
 codes/trainer/batch_size_optimizer.py | 66 ++++++++++++++++++++++----
 3 files changed, 56 insertions(+), 14 deletions(-)

diff --git a/codes/train.py b/codes/train.py
index e6819ac1..be9cabac 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -299,7 +299,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_encoder_build_ctc_alignments_medium/train_encoder_build_ctc_alignments.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 5dab9548..2e4f11d8 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -279,7 +279,7 @@ class ExtensibleTrainer(BaseModel):
 
         # Now do a forward and backward pass for each gradient accumulation step.
         new_states = {}
-        self.batch_size_optimizer.focus(step.get_optimizers()[-1])
+        self.batch_size_optimizer.focus(net)
         for m in range(self.batch_factor):
             ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
             for k, v in ns.items():
diff --git a/codes/trainer/batch_size_optimizer.py b/codes/trainer/batch_size_optimizer.py
index c2695935..dd4c46b4 100644
--- a/codes/trainer/batch_size_optimizer.py
+++ b/codes/trainer/batch_size_optimizer.py
@@ -1,3 +1,6 @@
+import math
+import random
+
 import torch
 
 from utils.util import opt_get
@@ -33,34 +36,73 @@ class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
 # BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
 # Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
+# Special note: this optimizer will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
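+# Configuration is read from opt_train['batch_size_optimizer'] (see __init__ below); supported keys
+# and their defaults: max_full_batches (10), poll_parameters (8), recalculate_directions_steps (1).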
 class GradientDirectionOptimizer(BatchSizeOptimizer):
     def __init__(self, opt_train):
-        self.mbf = opt_train['mega_batch_factor']
         self.opt = opt_train['batch_size_optimizer']
         self.max_full_batches = opt_get(self.opt, ['max_full_batches'], 10)
         self.parameters_to_poll = opt_get(self.opt, ['poll_parameters'], 8)
         self.recalculate_directions_every = opt_get(self.opt, ['recalculate_directions_steps'], 1)
-        self.last_number_iterations = 0
+        self.current_model = None
+
+        # Metrics
+        self.steps_taken = 0
+        self.last_number_iterations = torch.zeros((128,))
+        self.last_number_iterations_i = 0
+        self.last_number_iterations_filled = False
 
     def vector_angle(self, v1, v2):
+        if torch.all(v1 == 0) or torch.all(v2 == 0):
+            return torch.tensor(0.0, device=v1.device)  # Float dtype keeps this stackable with arccos outputs.
         with torch.no_grad():
             v1 = v1.flatten()
             v2 = v2.flatten()
             v1_norm = (v1 ** 2).sum().sqrt()
             v2_norm = (v2 ** 2).sum().sqrt()
-            angle = torch.arccos((v1 * v2) / (v1_norm * v2_norm))
+            angle = torch.arccos(torch.dot(v1, v2) / (v1_norm * v2_norm))
             return angle
 
-    def focus(self, optimizer):
-        optimizer._gradient_direction_optimizer_params = []
-        optimizer._gradient_direction_optimizer_prior_directions = []
-        optimizer._gradient_direction_optimizer_prior_grads = []
-        optimizer._gradient_direction_optimizer_direction_change_magnitudes = []
-        optimizer._gradient_direction_optimizer_step = 0
-        self.current_opt = optimizer
+    def focus(self, model):
+        if not hasattr(model, '_gradient_direction_optimizer_finished') or model._gradient_direction_optimizer_finished:
+            all_params = list(filter(lambda t: '.weight' in t[0] and t[1].requires_grad, list(model.named_parameters())))  # Extracts weight parameters. Who cares about biases anyways? :)
+            num_params = min(len(all_params), self.parameters_to_poll)
+            model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
+            model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
+            model._gradient_direction_optimizer_direction_change_magnitudes = [math.pi for _ in range(num_params)]
+            model._gradient_direction_optimizer_prior_grads = None
+            model._gradient_direction_optimizer_step = 0
+            model._gradient_direction_optimizer_finished = False
+        self.current_model = model
 
     def should_step(self, it):
-        self.last_number_iterations += 1
+        model = self.current_model
+        model._gradient_direction_optimizer_step += 1
+        cur_grads = [p.grad.detach().clone() for k, p in model._gradient_direction_optimizer_params]
+        if model._gradient_direction_optimizer_prior_grads is not None:
+            cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
+            delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
+            delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean().item()
+            model._gradient_direction_optimizer_prior_directions = cur_dir
+            model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
+            if delta_delta_dir < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
+                # delta_delta_dir < 0 means the change in gradient direction is growing from batch to batch. Halt accumulation here.
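+                # Worked example with hypothetical angles between successive polled gradients of
+                # 0.50, 0.30 and 0.45 rad: delta_dir evolves +0.50 -> -0.20 -> +0.15, so the final
+                # batch yields delta_delta_dir = -0.20 - 0.15 = -0.35 < 0 and accumulation halts.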
+                model._gradient_direction_optimizer_finished = True
+                self.record_number_steps(model._gradient_direction_optimizer_step)
+                return True
+        model._gradient_direction_optimizer_prior_grads = cur_grads
+        return False
+
+    def record_number_steps(self, steps):
+        # Ring buffer over the last 128 megabatches, used for the rolling average below.
+        self.last_number_iterations[self.last_number_iterations_i] = steps
+        if self.last_number_iterations_i == self.last_number_iterations.shape[0]-1:
+            self.last_number_iterations_filled = True
+        self.last_number_iterations_i = (self.last_number_iterations_i + 1) % self.last_number_iterations.shape[0]
+        self.steps_taken += 1
 
     def get_statistics(self):
-        return {"last_number_iterations_before_step": self.last_number_iterations}
+        res = {"batch_size_opt_total_steps": self.steps_taken}
+        if self.last_number_iterations_filled:
+            res["batch_size_opt_avg_iterations_per_step"] = self.last_number_iterations.mean().item()
+        else:
+            res["batch_size_opt_avg_iterations_per_step"] = self.last_number_iterations[:self.last_number_iterations_i].mean().item()
+        return res
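
Usage sketch (not part of the patch): a minimal loop driving the optimizer the way
ExtensibleTrainer does above. The toy model, optimizer, random data and the import path
are hypothetical stand-ins; only the opt_train keys come from the code in this patch.

    import torch
    import torch.nn as nn

    from trainer.batch_size_optimizer import GradientDirectionOptimizer  # assumed import path

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    opt_train = {'batch_size_optimizer': {'max_full_batches': 10, 'poll_parameters': 8}}
    bso = GradientDirectionOptimizer(opt_train)

    for it in range(100):
        bso.focus(model)  # re-polls weights only once the previous megabatch finished
        inputs, targets = torch.randn(4, 16), torch.randn(4, 1)
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()   # gradients accumulate across micro-batches until should_step() fires
        if bso.should_step(it):  # True once gradient directions diverge or max_full_batches is hit
            optimizer.step()
            optimizer.zero_grad()
    print(bso.get_statistics())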