diff --git a/README.md b/README.md index d420e6c..dfd91cd 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ out = linear(x.to(torch.float16)) ## Features - 8-bit Matrix multiplication with mixed precision decomposition - LLM.int8() inference -- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory) +- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory) - Stable Embedding Layer: Improved stability through better initialization, and normalization - 8-bit quantization: Quantile, Linear, and Dynamic quantization - Fast quantile estimation: Up to 100x faster than other algorithms diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 95a7c4f..166e38f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA: lib.crmsprop32bit_g32, lib.crmsprop32bit_g16, ) + str2optimizer32bit["lion"] = ( + lib.clion32bit_g32, + lib.clion32bit_g16, + ) str2optimizer32bit["adagrad"] = ( lib.cadagrad32bit_g32, lib.cadagrad32bit_g16, @@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA: lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16, ) + str2optimizer8bit["lion"] = ( + lib.clion_static_8bit_g32, + lib.clion_static_8bit_g16, + ) str2optimizer8bit["lamb"] = ( lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16, @@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA: lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16, ) + str2optimizer8bit_blockwise["lion"] = ( + lib.clion_8bit_blockwise_fp32, + lib.clion_8bit_blockwise_fp16, + ) str2optimizer8bit_blockwise["adagrad"] = ( lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16, diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index a2fb6af..4a00f57 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -19,7 +19,7 @@ class Lion(Optimizer1State): block_wise=True, ): super().__init__( - "rmsprop", + "lion", params, lr, betas, @@ -46,7 +46,7 @@ class 
Lion8bit(Optimizer1State): block_wise=True, ): super().__init__( - "rmsprop", + "lion", params, lr, betas, @@ -73,7 +73,7 @@ class Lion32bit(Optimizer1State): block_wise=True, ): super().__init__( - "rmsprop", + "lion", params, lr, betas, diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 08b9b44..a871a55 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -790,6 +790,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value @@ -890,6 +891,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p, p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); @@ -1219,6 +1221,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c if(unorm != NULL) local_unorm += s1_vals[j]*s1_vals[j]; break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; @@ -1321,6 +1324,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); @@ -1664,6 +1668,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char else s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; @@ -1701,6 
+1706,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char case MOMENTUM: p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; + case LION: case RMSPROP: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); @@ -2699,6 +2705,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) @@ -2710,6 +2718,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) @@ -2742,6 +2752,8 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half) MAKE_PreconditionStatic8bit1State(MOMENTUM, float) MAKE_PreconditionStatic8bit1State(RMSPROP, half) MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) #define MAKE_optimizerStatic8bit1State(oname, gtype) \ template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ @@ -2758,6 +2770,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half) MAKE_optimizerStatic8bit1State(MOMENTUM, float) MAKE_optimizerStatic8bit1State(RMSPROP, half) MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) #define MAKE_PreconditionStatic8bit2State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, 
gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ @@ -2849,5 +2863,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/ops.cu b/csrc/ops.cu index e770e10..cdd8a27 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -120,6 +120,7 @@ template void optimizer32bit(T* g, T* p, case MOMENTUM: case RMSPROP: case ADAGRAD: + case LION: if(max_unorm > 0.0f) { @@ -163,6 +164,7 @@ template void optimizerStatic8bit(T* p, T* g, case MOMENTUM: case RMSPROP: case ADAGRAD: + case LION: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -198,6 +200,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g case MOMENTUM: case RMSPROP: case ADAGRAD: + case LION: num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n % BLOCKSIZE_1STATE == 0 ? 
num_blocks : num_blocks + 1; kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, @@ -707,6 +710,8 @@ MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) @@ -726,6 +731,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half) MAKE_optimizerStatic8bit(MOMENTUM, float) MAKE_optimizerStatic8bit(RMSPROP, half) MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ @@ -738,6 +745,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 31d4dd8..9f06435 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -70,6 +70,7 @@ typedef enum Optimizer_t RMSPROP = 2, LARS = 3, ADAGRAD = 4, + LION = 5, } Optimizer_t; typedef enum Transform_t diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index d8b2290..4caa7e8 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32) MAKE_FUNC32(adam, ADAM, half, 16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) +MAKE_FUNC32(lion, LION, float, 32) +MAKE_FUNC32(lion, LION, half, 16) MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, half, 16) @@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32) 
MAKE_FUNC8(momentum, MOMENTUM, half, 16) MAKE_FUNC8(rmsprop, RMSPROP, float, 32) MAKE_FUNC8(rmsprop, RMSPROP, half, 16) +MAKE_FUNC8(lion, LION, float, 32) +MAKE_FUNC8(lion, LION, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ @@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) +MAKE_BLOCKWISE8(lion, LION, half, 16) +MAKE_BLOCKWISE8(lion, LION, float, 32) MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) @@ -161,6 +167,8 @@ extern "C" MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) + MAKE_CFUNC32(lion, float, 32) + MAKE_CFUNC32(lion, half, 16) MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) @@ -183,6 +191,8 @@ extern "C" MAKE_CFUNC8(momentum, half, 16) MAKE_CFUNC8(rmsprop, float, 32) MAKE_CFUNC8(rmsprop, half, 16) + MAKE_CFUNC8(lion, float, 32) + MAKE_CFUNC8(lion, half, 16) #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ @@ -196,6 +206,8 @@ extern "C" MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) + MAKE_CBLOCKWISE8(lion, LION, half, 16) + MAKE_CBLOCKWISE8(lion, LION, float, 32) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) diff --git a/tests/test_optim.py b/tests/test_optim.py index 3df2dad..a11ba85 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -50,6 +50,10 @@ str2optimizers["rmsprop"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), ) +str2optimizers["lion"] = ( +    lambda pxx: Lion(pxx, lr=0.001, betas=(0.9, 0.99)), +    lambda pxx:
bnb.optim.Lion(pxx, lr=0.001, betas=(0.9, 0.99), block_wise=False), +) str2optimizers["adam8bit"] = ( torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),