import math
import random
import torch
from utils.util import opt_get


def create_batch_size_optimizer(opt_train):
    if 'batch_size_optimizer' in opt_train.keys():
        if opt_train['batch_size_optimizer']['type'] == 'gradient_direction':
            return GradientDirectionOptimizer(opt_train)
    return MegabatchBatchSizeOptimizer(opt_train)
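
# Illustrative sketch of the `opt_train` options these classes read. The key names mirror the
# lookups below (and in GradientDirectionOptimizer.__init__); the surrounding structure and the
# values shown are assumptions for this example only:
#
#   opt_train = {
#       'batch_size_optimizer': {
#           'type': 'gradient_direction',       # any other value falls back to MegabatchBatchSizeOptimizer
#           'max_full_batches': 10,             # hard cap on accumulated batches per optimizer step
#           'poll_parameters': 8,               # number of weight tensors whose gradients are polled
#           'recalculate_directions_steps': 1,
#       }
#   }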

# Base class for BatchSizeOptimizers.
class BatchSizeOptimizer:
    def focus(self, optimizer):
        pass

    def should_step(self, it):
        raise NotImplementedError

    def get_statistics(self):
        return {}


# BatchSizeOptimizer that just steps every megabatch.
class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
    def __init__(self, opt_train):
        pass

    def should_step(self, it):
        return True

# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
# Special note: this optimizer will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
class GradientDirectionOptimizer(BatchSizeOptimizer):
    def __init__(self, opt_train):
        self.opt = opt_train['batch_size_optimizer']
        self.max_full_batches = opt_get(self.opt, ['max_full_batches'], 10)
        self.parameters_to_poll = opt_get(self.opt, ['poll_parameters'], 8)
        self.recalculate_directions_every = opt_get(self.opt, ['recalculate_directions_steps'], 1)
        self.current_model = None

        # Metrics
        self.steps_taken = 0
        self.last_number_iterations = torch.zeros((128,))
        self.last_number_iterations_i = 0
        self.last_number_iterations_filled = False

    def vector_angle(self, v1, v2):
        # Angle between the two gradients, treated as flattened vectors: arccos(v1.v2 / (|v1| |v2|)).
        # Zero gradients have no direction, so treat them as a zero angle.
        if torch.all(v1 == 0) or torch.all(v2 == 0):
            return torch.tensor(0, device=v1.device)
        with torch.no_grad():
            v1 = v1.flatten()
            v2 = v2.flatten()
            v1_norm = (v1 ** 2).sum().sqrt()
            v2_norm = (v2 ** 2).sum().sqrt()
            angle = torch.arccos((torch.dot(v1, v2)) / (v1_norm * v2_norm))
            return angle

    def focus(self, model):
        if not hasattr(model, '_gradient_direction_optimizer_finished') or model._gradient_direction_optimizer_finished:
            all_params = list(filter(lambda t: '.weight' in t[0] and t[1].requires_grad, list(model.named_parameters())))  # Extracts weight parameters. Who cares about biases anyways? :)
            num_params = min(len(all_params), self.parameters_to_poll)
            model._gradient_direction_optimizer_params = random.sample(all_params, num_params)
            model._gradient_direction_optimizer_prior_directions = [0 for _ in range(num_params)]
            model._gradient_direction_optimizer_direction_change_magnitudes = [math.pi for _ in range(num_params)]
            model._gradient_direction_optimizer_prior_grads = None
            model._gradient_direction_optimizer_step = 0
            model._gradient_direction_optimizer_finished = False
        self.current_model = model

    def should_step(self, it):
        model = self.current_model
        model._gradient_direction_optimizer_step += 1
        cur_grads = [p.grad.detach().clone() for k, p in model._gradient_direction_optimizer_params]
        if model._gradient_direction_optimizer_prior_grads is not None:
            # Angle between this batch's gradient and the previous batch's gradient for each polled parameter.
            cur_dir = [self.vector_angle(lgrad, cgrad) for lgrad, cgrad in zip(model._gradient_direction_optimizer_prior_grads, cur_grads)]
            delta_dir = [(cdir - ldir) for cdir, ldir in zip(cur_dir, model._gradient_direction_optimizer_prior_directions)]
            delta_delta_dir = torch.stack([pdd - cdd for pdd, cdd in zip(model._gradient_direction_optimizer_direction_change_magnitudes, delta_dir)]).mean().item()
            model._gradient_direction_optimizer_prior_directions = cur_dir
            model._gradient_direction_optimizer_direction_change_magnitudes = delta_dir
            if delta_delta_dir < 0 or model._gradient_direction_optimizer_step >= self.max_full_batches:
                # < 0 means the change in gradient direction is growing. Halt batch accumulation here.
                model._gradient_direction_optimizer_finished = True
                self.record_number_steps(model._gradient_direction_optimizer_step)
                return True
        model._gradient_direction_optimizer_prior_grads = cur_grads
        return False

    def record_number_steps(self, steps):
        self.last_number_iterations[self.last_number_iterations_i] = steps
        if self.last_number_iterations_i == self.last_number_iterations.shape[0] - 1:
            self.last_number_iterations_filled = True
        self.last_number_iterations_i = (self.last_number_iterations_i + 1) % self.last_number_iterations.shape[0]
        self.steps_taken += 1

    def get_statistics(self):
        res = {"batch_size_opt_total_steps": self.steps_taken}
        if self.last_number_iterations_filled:
            res["batch_size_opt_avg_iterations_per_step"] = self.last_number_iterations.mean().item()
        else:
            res["batch_size_opt_avg_iterations_per_step"] = self.last_number_iterations[:self.last_number_iterations_i].mean().item()
        return res
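

if __name__ == '__main__':
    # Minimal usage sketch, not part of the training pipeline in this repo: the toy model, data
    # and optimizer below are assumptions made purely to show how focus()/should_step() are
    # intended to drive gradient accumulation.
    import torch.nn as nn

    opt_train = {'batch_size_optimizer': {'type': 'gradient_direction', 'max_full_batches': 10, 'poll_parameters': 4}}
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    bso = create_batch_size_optimizer(opt_train)
    bso.focus(model)  # Samples the weight tensors whose gradient directions will be polled.
    for it in range(100):
        x, y = torch.randn(8, 16), torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()  # Gradients accumulate across mini-batches until should_step() fires.
        if bso.should_step(it):
            optimizer.step()
            optimizer.zero_grad()
            bso.focus(model)  # Re-arms the optimizer for the next accumulation cycle.
    print(bso.get_statistics())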