Fix critical bug in PytorchLARS().step: Undefined variable

This commit is contained in:
Tom Aarsen 2022-10-27 13:19:09 +02:00
parent f6978ae2a2
commit 4a05df34c2

View File

@ -181,7 +181,7 @@ class PytorchLARS(Optimizer):
state = self.state[p]
d_p = p.grad
if weight_decay != 0:
d_p = d_p.add(param, alpha=weight_decay)
d_p = d_p.add(p, alpha=weight_decay)
if momentum != 0:
buf = state.get("momentum_buffer", None)