Update BSO to use the proper step size
This commit is contained in:
parent
820a29f81e
commit
836eb08afb
|
@ -36,7 +36,7 @@ class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
|
||||||
|
|
||||||
# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
|
# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
|
||||||
# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
|
# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
|
||||||
# Special note: this optimizer will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
|
# Special note: this class will ALWAYS accumulate, at a minimum, 3 batches. Plan accordingly.
|
||||||
class GradientDirectionOptimizer(BatchSizeOptimizer):
|
class GradientDirectionOptimizer(BatchSizeOptimizer):
|
||||||
def __init__(self, opt_train):
|
def __init__(self, opt_train):
|
||||||
self.opt = opt_train['batch_size_optimizer']
|
self.opt = opt_train['batch_size_optimizer']
|
||||||
|
@ -88,6 +88,10 @@ class GradientDirectionOptimizer(BatchSizeOptimizer):
|
||||||
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
|
# <0 means the gradient direction is getting larger. Halt batch accumulation here.
|
||||||
model._gradient_direction_optimizer_finished = True
|
model._gradient_direction_optimizer_finished = True
|
||||||
self.record_number_steps(model._gradient_direction_optimizer_step)
|
self.record_number_steps(model._gradient_direction_optimizer_step)
|
||||||
|
# Fix the gradients. We've accumulated _gradient_direction_optimizer_step steps total, so we need to divide the grads by that.
|
||||||
|
for p in model.parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
p.grad = p.grad / model._gradient_direction_optimizer_step
|
||||||
return True
|
return True
|
||||||
model._gradient_direction_optimizer_prior_grads = cur_grads
|
model._gradient_direction_optimizer_prior_grads = cur_grads
|
||||||
return False
|
return False
|
||||||
|
|
Loading…
Reference in New Issue
Block a user