diff --git a/codes/trainer/batch_size_optimizer.py b/codes/trainer/batch_size_optimizer.py
index dd4c46b4..1271ba40 100644
--- a/codes/trainer/batch_size_optimizer.py
+++ b/codes/trainer/batch_size_optimizer.py
@@ -36,7 +36,7 @@ class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
 
 # BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
 # Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
-# Special note: this optimizer will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
+# Special note: this class will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
 class GradientDirectionOptimizer(BatchSizeOptimizer):
     def __init__(self, opt_train):
         self.opt = opt_train['batch_size_optimizer']
@@ -88,6 +88,10 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
             # <0 means the gradient direction is getting larger. Halt batch accumulation here.
             model._gradient_direction_optimizer_finished = True
             self.record_number_steps(model._gradient_direction_optimizer_step)
+            # Fix the gradients. We've accumulated _gradient_direction_optimizer_step steps total, so we need to divide the grads by that.
+            for p in model.parameters():
+                if p.requires_grad:
+                    p.grad = p.grad / model._gradient_direction_optimizer_step
             return True
         model._gradient_direction_optimizer_prior_grads = cur_grads
         return False
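
The added block divides the accumulated gradients by the number of accumulated batches, so the eventual optimizer step is taken on the mean gradient rather than the sum; without it, the effective step size would grow with however many batches this class happened to accumulate. A minimal, self-contained sketch of that idea, using a toy model and illustrative names rather than the repo's GradientDirectionOptimizer machinery:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    accumulated_steps = 3  # stand-in for model._gradient_direction_optimizer_step
    optimizer.zero_grad()
    for _ in range(accumulated_steps):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()
        loss.backward()  # .grad fields are summed across the accumulated batches

    # Normalize so the update reflects the mean gradient of the accumulated batches,
    # independent of how many batches were accumulated before stepping.
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            p.grad = p.grad / accumulated_steps

    optimizer.step()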