forked from mrq/bitsandbytes-rocm
Copied over Analysis Adam.
This commit is contained in:
parent
d06c5776e4
commit
eaf35ab949
@@ -26,3 +26,202 @@ class Adam32bit(Optimizer2State):
            weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)


class AnalysisAdam(torch.optim.Optimizer):
    """Implements 8-bit Adam and performs error analysis.

    This implementation is modified from torch.optim.Adam based on:
    `Fixed Weight Decay Regularization in Adam`
    (see https://arxiv.org/abs/1711.05101)

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam\: A Method for Stochastic