forked from mrq/DL-Art-School
batch_size_optimizer works. sweet! no more tuning batch sizes.
parent 18938248e4
commit 3d946356f8
@@ -299,7 +299,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_encoder_build_ctc_alignments_medium/train_encoder_build_ctc_alignments.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -279,7 +279,7 @@ class ExtensibleTrainer(BaseModel):
         # Now do a forward and backward pass for each gradient accumulation step.
         new_states = {}
-        self.batch_size_optimizer.focus(step.get_optimizers()[-1])
+        self.batch_size_optimizer.focus(net)
         for m in range(self.batch_factor):
             ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
             for k, v in ns.items():
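For context, focus() now receives the network being trained rather than the last optimizer of the step, because the gradient-direction optimizer samples parameters to poll directly from the model. Below is a minimal illustrative sketch (not the repository's ExtensibleTrainer) of how a focus()/should_step() pair can drive a gradient-accumulation loop; the toy model, the StepEveryN stand-in, and get_minibatch are assumptions for illustration only.

# Illustrative sketch only -- a stripped-down accumulation loop, not DL-Art-School code.
import torch
import torch.nn as nn

model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

def get_minibatch():  # hypothetical data source
    return torch.randn(8, 16), torch.randn(8, 1)

class StepEveryN:
    # Trivial stand-in for a BatchSizeOptimizer: step after every n micro-batches.
    def __init__(self, n=4):
        self.n = n
    def focus(self, net):
        pass  # the gradient-direction variant samples weight parameters to poll here
    def should_step(self, it):
        return (it + 1) % self.n == 0

bso = StepEveryN(n=4)
for step in range(4):                      # outer training steps
    bso.focus(model)                       # focused on the network, as in the change above
    for m in range(bso.n):                 # accumulate micro-batches
        x, y = get_minibatch()
        loss_fn(model(x), y).backward()    # gradients accumulate across micro-batches
        if bso.should_step(m):
            optimizer.step()
            optimizer.zero_grad()

The real GradientDirectionOptimizer replaces this fixed step-every-n rule with the gradient-angle criterion introduced in the next file.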
@@ -1,3 +1,6 @@
+import math
 import random
 
+import torch
+
 from utils.util import opt_get
@@ -33,34 +36,73 @@ class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
 # BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
 # Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
+# Special note: this optimizer will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
 class GradientDirectionOptimizer(BatchSizeOptimizer):
     def __init__(self, opt_train):
         self.mbf = opt_train['mega_batch_factor']
         self.opt = opt_train['batch_size_optimizer']
         self.max_full_batches = opt_get(self.opt, ['max_full_batches'], 10)
         self.parameters_to_poll = opt_get(self.opt, ['poll_parameters'], 8)
         self.recalculate_directions_every = opt_get(self.opt, ['recalculate_directions_steps'], 1)
-        self.last_number_iterations = 0
+        self.current_model = None
 
+        # Metrics
+        self.steps_taken = 0
+        self.last_number_iterations = torch.zeros((128,))
+        self.last_number_iterations_i = 0
+        self.last_number_iterations_filled = False
 
     def vector_angle(self, v1, v2):
+        if torch.all(v1 == 0) or torch.all(v2 == 0):
+            return torch.tensor(0, device=v1.device)
         with torch.no_grad():
             v1 = v1.flatten()
             v2 = v2.flatten()
             v1_norm = (v1 ** 2).sum().sqrt()
             v2_norm = (v2 ** 2).sum().sqrt()
-            angle = torch.arccos((v1 * v2) / (v1_norm * v2_norm))
+            angle = torch.arccos((torch.dot(v1, v2)) / (v1_norm * v2_norm))
             return angle
 
-    def focus(self, optimizer):
-        optimizer._gradient_direction_optimizer_params = []
-        optimizer._gradient_direction_optimizer_prior_directions = []
-        optimizer._gradient_direction_optimizer_prior_grads = []
-        optimizer._gradient_direction_optimizer_direction_change_magnitudes = []
-        optimizer._gradient_direction_optimizer_step = 0
-        self.current_opt = optimizer
+    def focus(self, model):
+        if not hasattr(model, '_gradient_direction_optimizer_finished') or model._gradient_direction_optimizer_finished:
+            all_params = list(filter(lambda t: '.weight' in t[0] and t[1].requires_grad, list(model.named_parameters())))  # Extracts weight parameters. Who cares about biases anyways? :)
+            num_params = min(len(all_params), self.parameters_to_poll)
+            model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
+            model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
+            model._gradient_direction_optimizer_direction_change_magnitudes = [math.pi for _ in range(num_params)]
+            model._gradient_direction_optimizer_prior_grads = None
+            model._gradient_direction_optimizer_step = 0
+            model._gradient_direction_optimizer_finished = False
+        self.current_model = model
 
     def should_step(self, it):
-        self.last_number_iterations += 1
+        model = self.current_model
+        model._gradient_direction_optimizer_step += 1
+        cur_grads = [p.grad.detach().clone() for k, p in model._gradient_direction_optimizer_params]
+        if model._gradient_direction_optimizer_prior_grads is not None:
+            cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
+            delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
+            delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean().item()
+            model._gradient_direction_optimizer_prior_directions = cur_dir
+            model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
+            if delta_delta_dir < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
+                # <0 means the change in gradient direction is getting larger. Halt batch accumulation here.
+                model._gradient_direction_optimizer_finished = True
+                self.record_number_steps(model._gradient_direction_optimizer_step)
+                return True
+        model._gradient_direction_optimizer_prior_grads = cur_grads
+        return False
 
+    def record_number_steps(self, steps):
+        self.last_number_iterations[self.last_number_iterations_i] = steps
+        if self.last_number_iterations_i == self.last_number_iterations.shape[0]-1:
+            self.last_number_iterations_filled = True
+        self.last_number_iterations_i = (self.last_number_iterations_i + 1) % self.last_number_iterations.shape[0]
+        self.steps_taken += 1
+
     def get_statistics(self):
-        return {"last_number_iterations_before_step": self.last_number_iterations}
+        res = {"batch_size_opt_total_steps": self.steps_taken}
+        if self.last_number_iterations_filled:
+            res["batch_size_opt_avg_iterations_per_step"] = self.last_number_iterations.mean().item()
+        else:
+            res["batch_size_opt_avg_iterations_per_step"] = self.last_number_iterations[:self.last_number_iterations_i].mean().item()
+        return res
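For intuition, here is a small, self-contained sketch of the stopping rule the new should_step() implements: after each accumulated micro-batch, poll the gradients of a few randomly sampled weight parameters, measure the angle between consecutive gradient snapshots with the corrected dot-product formula, and stop accumulating once the average change in that angle starts growing, or once max_full_batches is reached. The toy model, random data, and the clamp guard are assumptions for illustration; the polled-parameter selection and the delta-of-delta test mirror the diff above.

# Toy illustration of the gradient-direction stopping rule (not the repository's code path).
import math
import random
import torch
import torch.nn as nn

def vector_angle(v1, v2):
    # Angle between two flattened gradient tensors, as in the corrected vector_angle().
    if torch.all(v1 == 0) or torch.all(v2 == 0):
        return torch.tensor(0.0)
    v1, v2 = v1.flatten(), v2.flatten()
    cos = torch.dot(v1, v2) / (v1.norm() * v2.norm())
    return torch.arccos(cos.clamp(-1.0, 1.0))  # clamp guards against float round-off

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1))
loss_fn = nn.MSELoss()

# Constants mirroring the option keys read in GradientDirectionOptimizer.__init__().
max_full_batches = 10
parameters_to_poll = 4

# Sample a few weight parameters to poll, as focus() does.
weights = [(k, p) for k, p in model.named_parameters() if '.weight' in k and p.requires_grad]
polled = random.sample(weights, min(parameters_to_poll, len(weights)))

prior_grads = None
prior_dirs = [0.0] * len(polled)
prior_change = [math.pi] * len(polled)

for step in range(1, max_full_batches + 1):
    x, y = torch.randn(8, 32), torch.randn(8, 1)
    loss_fn(model(x), y).backward()            # gradients accumulate across micro-batches
    cur_grads = [p.grad.detach().clone() for _, p in polled]
    if prior_grads is not None:
        cur_dirs = [vector_angle(a, b) for a, b in zip(prior_grads, cur_grads)]
        delta_dir = [c - l for c, l in zip(cur_dirs, prior_dirs)]
        delta_delta = torch.stack([torch.as_tensor(p) - d for p, d in zip(prior_change, delta_dir)]).mean().item()
        prior_dirs, prior_change = cur_dirs, delta_dir
        if delta_delta < 0 or step >= max_full_batches:
            print(f'step optimizer after accumulating {step} micro-batches')
            break
    prior_grads = cur_grads

Because the prior change magnitudes start at pi, the test cannot fire before the third micro-batch, which is what the "ALWAYS accumulate, at a minimum, 3 batches" note in the class comment refers to.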