Update BSO to use the proper step size

This commit is contained in:
James Betker 2022-02-10 09:44:15 -07:00
parent 820a29f81e
commit 836eb08afb

View File

@ -36,7 +36,7 @@ class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
# Special note: this optimizer will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
# Special note: this class will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
class GradientDirectionOptimizer(BatchSizeOptimizer):
def __init__(self, opt_train):
self.opt = opt_train['batch_size_optimizer']
@ -88,6 +88,10 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
model._gradient_direction_optimizer_finished = True
self.record_number_steps(model._gradient_direction_optimizer_step)
# Fix the gradients. We've accumulated _gradient_direction_optimizer_step steps total, so we need to divide the grads by that.
for p in model.parameters():
if p.requires_grad:
p.grad = p.grad / model._gradient_direction_optimizer_step
return True
model._gradient_direction_optimizer_prior_grads = cur_grads
return False