diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index cd2a9da..a2fb6af 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -9,30 +9,21 @@ class Lion(Optimizer1State): def __init__( self, params, - lr=1e-2, - alpha=0.99, - eps=1e-8, + lr=1e-4, + betas=(0.9, 0.99), weight_decay=0, - momentum=0, - centered=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, ): - if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) - if centered: - raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( "rmsprop", params, lr, - (alpha, momentum), - eps, + betas, + 0., weight_decay, optim_bits, args, @@ -46,29 +37,20 @@ class Lion8bit(Optimizer1State): def __init__( self, params, - lr=1e-2, - alpha=0.99, - eps=1e-8, + lr=1e-4, + betas=(0.9, 0.99), weight_decay=0, - momentum=0, - centered=False, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, ): - if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) - if centered: - raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( "rmsprop", params, lr, - (alpha, momentum), - eps, + betas, + 0., weight_decay, 8, args, @@ -82,30 +64,20 @@ class Lion32bit(Optimizer1State): def __init__( self, params, - lr=1e-2, - alpha=0.99, - eps=1e-8, + lr=1e-4, + betas=(0.9, 0.99), weight_decay=0, - momentum=0, - centered=False, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, ): - - if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) - if centered: - raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( "rmsprop", params, lr, - (alpha, momentum), - eps, + betas, + 0., weight_decay, 32, args,