import math
import random

import torch

from utils.util import opt_get


def create_batch_size_optimizer(opt_train):
    """Build the batch size optimizer selected by the training options.

    Returns a GradientDirectionOptimizer when
    opt_train['batch_size_optimizer']['type'] == 'gradient_direction'. In every
    other case (including when the 'batch_size_optimizer' key is absent) a
    MegabatchBatchSizeOptimizer is returned.
    """
    if 'batch_size_optimizer' in opt_train:
        if opt_train['batch_size_optimizer']['type'] == 'gradient_direction':
            return GradientDirectionOptimizer(opt_train)
    return MegabatchBatchSizeOptimizer(opt_train)


# Base class for BatchSizeOptimizers.
class BatchSizeOptimizer:
    def focus(self, optimizer):
        """Hook for selecting what subsequent should_step() calls act on. No-op here."""
        pass

    def should_step(self, it):
        """Return True when the caller should step now. Subclasses must override."""
        raise NotImplementedError

    def get_statistics(self):
        """Return a dict of statistics to report. The base class reports nothing."""
        return {}


# BatchSizeOptimizer that just steps every megabatch.
class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
    def __init__(self, opt_train):
        # opt_train is accepted so both optimizers share a constructor signature; it is unused.
        pass

    def should_step(self, it):
        """Always request a step."""
        return True


# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
# Special note: this optimizer will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
class GradientDirectionOptimizer(BatchSizeOptimizer):
    def __init__(self, opt_train):
        """Read settings from opt_train['batch_size_optimizer'] and reset the metrics."""
        self.opt = opt_train['batch_size_optimizer']
        # should_step() forces a step once this many calls were made since the last reset.
        self.max_full_batches = opt_get(self.opt, ['max_full_batches'], 10)
        # Upper bound on how many weight parameters focus() samples for polling.
        self.parameters_to_poll = opt_get(self.opt, ['poll_parameters'], 8)
        # NOTE(review): stored but not read by any code visible in this file.
        self.recalculate_directions_every = opt_get(self.opt, ['recalculate_directions_steps'], 1)
        self.current_model = None

        # Metrics
        self.steps_taken = 0
        # Ring buffer holding the accumulation count of the most recent 128 steps.
        self.last_number_iterations = torch.zeros((128,))
        self.last_number_iterations_i = 0
        self.last_number_iterations_filled = False

    def vector_angle(self, v1, v2):
        """Return the angle in radians between v1 and v2, both treated as flat vectors.

        If either input is entirely zero the angle is undefined and a
        floating-point 0 is returned. The result is a 0-dim tensor on the
        device of the inputs.
        """
        if torch.all(v1 == 0) or torch.all(v2 == 0):
            # Use a floating dtype so downstream stack()/mean() calls never see an
            # integer tensor (mean() rejects integer dtypes).
            zero_dtype = v1.dtype if v1.is_floating_point() else torch.get_default_dtype()
            return torch.zeros((), dtype=zero_dtype, device=v1.device)
        with torch.no_grad():
            v1 = v1.flatten()
            v2 = v2.flatten()
            v1_norm = (v1 ** 2).sum().sqrt()
            v2_norm = (v2 ** 2).sum().sqrt()
            cosine = torch.dot(v1, v2) / (v1_norm * v2_norm)
            # Rounding can push the cosine marginally outside [-1, 1] for (anti-)parallel
            # vectors, where arccos yields NaN; clamp so the angle is always defined.
            return torch.arccos(cosine.clamp(-1.0, 1.0))

    def focus(self, model):
        """Make `model` the target of subsequent should_step() calls.

        The polling state is stored as attributes on the model itself. It is
        (re)initialized when the model has never been seen before or when its
        previous accumulation cycle finished; otherwise the in-progress state
        is kept.
        """
        needs_reset = (
            not hasattr(model, '_gradient_direction_optimizer_finished')
            or model._gradient_direction_optimizer_finished
        )
        if needs_reset:
            # Extracts weight parameters. Who cares about biases anyways? :)
            all_params = [
                (name, param)
                for name, param in model.named_parameters()
                if '.weight' in name and param.requires_grad
            ]
            num_params = min(len(all_params), self.parameters_to_poll)
            model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
            model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
            # Seeded with pi (the largest possible angle) so that the first measured
            # change can never look like an increase.
            model._gradient_direction_optimizer_direction_change_magnitudes = [math.pi for _ in range(num_params)]
            model._gradient_direction_optimizer_prior_grads = None
            model._gradient_direction_optimizer_step = 0
            model._gradient_direction_optimizer_finished = False
        self.current_model = model

    def should_step(self, it):
        """Decide whether accumulation should stop and a step be taken now.

        Must be called after focus(). Each call compares the current gradients
        of the polled parameters with the gradients saved by the previous call.
        Returns True when the mean change in gradient direction starts growing,
        or when max_full_batches calls have been made since the last reset;
        returns False otherwise. `it` is accepted for interface compatibility
        and is not used.
        """
        model = self.current_model
        model._gradient_direction_optimizer_step += 1
        # A polled parameter that received no gradient is treated as having a zero
        # gradient (vector_angle() maps that to an angle of 0) instead of crashing.
        cur_grads = [
            p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
            for _, p in model._gradient_direction_optimizer_params
        ]
        prior_grads = model._gradient_direction_optimizer_prior_grads
        if prior_grads is not None:
            cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(prior_grads, cur_grads)]
            delta_dir = [
                cdir - ldir
                for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)
            ]
            if delta_dir:
                delta_delta_dir = torch.stack([
                    pdd - cdd
                    for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)
                ]).mean().item()
            else:
                # No parameters are being polled, so there is no direction signal;
                # only the max_full_batches cap can trigger a step.
                delta_delta_dir = 0.0
            model._gradient_direction_optimizer_prior_directions = cur_dir
            model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
            if delta_delta_dir < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
                # <0 means the gradient direction is getting larger. Halt batch accumulation here.
                model._gradient_direction_optimizer_finished = True
                self.record_number_steps(model._gradient_direction_optimizer_step)
                return True
        model._gradient_direction_optimizer_prior_grads = cur_grads
        return False

    def record_number_steps(self, steps):
        """Record how many accumulation iterations the step that was just taken used."""
        capacity = self.last_number_iterations.shape[0]
        self.last_number_iterations[self.last_number_iterations_i] = steps
        if self.last_number_iterations_i == capacity - 1:
            # The last slot was just written: from now on the whole buffer holds valid data.
            self.last_number_iterations_filled = True
        self.last_number_iterations_i = (self.last_number_iterations_i + 1) % capacity
        self.steps_taken += 1

    def get_statistics(self):
        """Return the total number of steps and the recent average iterations per step.

        The average covers only the ring-buffer slots written so far. Before any
        step has been recorded it is reported as 0.0 rather than the NaN that
        the mean of an empty tensor would produce.
        """
        res = {"batch_size_opt_total_steps": self.steps_taken}
        if self.last_number_iterations_filled:
            window = self.last_number_iterations
        else:
            window = self.last_number_iterations[:self.last_number_iterations_i]
        res["batch_size_opt_avg_iterations_per_step"] = window.mean().item() if window.numel() > 0 else 0.0
        return res