diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 8c8a8f4..53533ee 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit +from .lion import Lion, Lion8bit, Lion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py new file mode 100644 index 0000000..cd2a9da --- /dev/null +++ b/bitsandbytes/optim/lion.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer1State + + +class Lion(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class Lion8bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class Lion32bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + )