Copied over Analysis Adam.

Tim Dettmers 2021-10-21 10:20:41 -07:00
parent d06c5776e4
commit eaf35ab949


@@ -26,3 +26,202 @@ class Adam32bit(Optimizer2State):
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
class AnalysisAdam(torch.optim.Optimizer):
"""Implements 8-bit Adam and performs error analysis.
This implementation is modified from torch.optim.Adam based on:
`Fixing Weight Decay Regularization in Adam`
(see https://arxiv.org/abs/1711.05101)
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
.. _Adam\: A Method for Stochastic
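
A minimal usage sketch, assuming AnalysisAdam follows the standard torch.optim.Optimizer interface with the hyperparameters listed in the docstring above; the import path and any analysis-specific keyword arguments appearing later in this diff are assumptions and are omitted or marked as such here.

import torch
import bitsandbytes as bnb  # assumed import path; AnalysisAdam is defined in this diff

model = torch.nn.Linear(128, 10)
criterion = torch.nn.CrossEntropyLoss()

# Construct the optimizer with the documented defaults.
optimizer = bnb.optim.AnalysisAdam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
)

inputs = torch.randn(32, 128)
targets = torch.randint(0, 10, (32,))

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()  # per the docstring, the update also performs 8-bit error analysis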