diff --git a/README.md b/README.md
index d420e6c..dfd91cd 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,7 @@ out = linear(x.to(torch.float16))
 ## Features
 - 8-bit Matrix multiplication with mixed precision decomposition
 - LLM.int8() inference
-- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
+- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
 - Stable Embedding Layer: Improved stability through better initialization, and normalization
 - 8-bit quantization: Quantile, Linear, and Dynamic quantization
 - Fast quantile estimation: Up to 100x faster than other algorithms
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 3f7b328..9840b47 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop32bit_g32,
         lib.crmsprop32bit_g16,
     )
+    str2optimizer32bit["lion"] = (
+        lib.clion32bit_g32,
+        lib.clion32bit_g16,
+    )
     str2optimizer32bit["adagrad"] = (
         lib.cadagrad32bit_g32,
         lib.cadagrad32bit_g16,
@@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_static_8bit_g32,
         lib.crmsprop_static_8bit_g16,
     )
+    str2optimizer8bit["lion"] = (
+        lib.clion_static_8bit_g32,
+        lib.clion_static_8bit_g16,
+    )
     str2optimizer8bit["lamb"] = (
         lib.cadam_static_8bit_g32,
         lib.cadam_static_8bit_g16,
@@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_8bit_blockwise_fp32,
         lib.crmsprop_8bit_blockwise_fp16,
     )
+    str2optimizer8bit_blockwise["lion"] = (
+        lib.clion_8bit_blockwise_fp32,
+        lib.clion_8bit_blockwise_fp16,
+    )
     str2optimizer8bit_blockwise["adagrad"] = (
         lib.cadagrad_8bit_blockwise_fp32,
         lib.cadagrad_8bit_blockwise_fp16,
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 8c8a8f4..53533ee 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager
 from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .lion import Lion, Lion8bit, Lion32bit
 from .sgd import SGD, SGD8bit, SGD32bit
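The three `str2optimizer*` tables above are how the Python layer resolves an optimizer name to the compiled CUDA entry points for 32-bit, 8-bit, and 8-bit blockwise state. The rule they dispatch to is the Lion update from "Symbolic Discovery of Optimization Algorithms" (Chen et al., 2023, arXiv:2302.06675). A minimal dense PyTorch sketch of one step, illustrative only and not the code path bitsandbytes executes:

```python
import torch

@torch.no_grad()
def lion_step(p, grad, m, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    """One dense Lion step (reference only; bitsandbytes runs this in CUDA).

    p, grad, m are tensors of the same shape; m is the single momentum state.
    """
    # decoupled weight decay, as in the LION branches added to the kernels below
    p.mul_(1.0 - lr * weight_decay)
    # the parameter update uses the sign of the beta1-interpolated momentum
    update = (beta1 * m + (1.0 - beta1) * grad).sign_()
    p.add_(update, alpha=-lr)
    # the momentum is then advanced with beta2 -- after the parameter update
    m.mul_(beta2).add_(grad, alpha=1.0 - beta2)
    return p, m
```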
diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py
new file mode 100644
index 0000000..2551b68
--- /dev/null
+++ b/bitsandbytes/optim/lion.py
@@ -0,0 +1,87 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+
+class Lion(Optimizer1State):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )
+
+
+class Lion8bit(Optimizer1State):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.,
+            weight_decay,
+            8,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )
+
+
+class Lion32bit(Optimizer1State):
+    def __init__(
+        self,
+        params,
+        lr=1e-4,
+        betas=(0.9, 0.99),
+        weight_decay=0,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+    ):
+        super().__init__(
+            "lion",
+            params,
+            lr,
+            betas,
+            0.,
+            weight_decay,
+            32,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+        )
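All three classes wrap `Optimizer1State`, passing `0.` for the unused `eps` slot; `Lion8bit` and `Lion32bit` just pin `optim_bits`. Illustrative usage on a toy module (hyperparameters here are arbitrary):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# drop-in replacement for a torch.optim optimizer; the momentum state is
# kept in 8 bits (blockwise-quantized by default via block_wise=True)
optimizer = bnb.optim.Lion8bit(
    model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2
)

loss = model(torch.randn(16, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```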
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 08b9b44..e0df802 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
   return __int_as_float(old);
 }
 
+// sign function for lion
+// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
+template <typename T>
+__device__ int sgn(T val)
+{
+    return (T(0) < val) - (val < T(0));
+}
+
 template <int STOCHASTIC>
 __device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
 {
@@ -743,7 +751,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
 __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
                 float* state1, float *unorm,
-                const float beta1, const float eps, const float weight_decay,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const int n)
 {
 
@@ -790,6 +798,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
           s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
           s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
           break;
+        case LION:
+          s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
+          break;
         case RMSPROP:
           s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
           s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
@@ -821,7 +832,7 @@ template<typename T, int OPTIMIZER>
 __launch_bounds__(TH, 1)
 __global__ void kOptimizer32bit1State(T *g, T *p,
                 float *state1, float *unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float eps, const float weight_decay,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
 {
 
@@ -890,6 +901,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
           p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
           break;
+        case LION:
+          p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
+          s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
+          break;
         case RMSPROP:
           s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
           p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
@@ -1158,7 +1173,7 @@ __global__ void
 __launch_bounds__(NUM_THREADS, 2)
 kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
                 float *unorm,
-                const float beta1,
+                const float beta1, const float beta2,
                 const float eps, const int step,
                 float* __restrict__ const quantiles1,
                 float* max1, float* new_max1,
@@ -1219,6 +1234,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
             if(unorm != NULL)
               local_unorm += s1_vals[j]*s1_vals[j];
             break;
+          case LION:
+            s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+            break;
           case RMSPROP:
             s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
             break;
@@ -1244,7 +1262,7 @@ template<typename T, int OPTIMIZER>
 __global__ void
 kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                 const float *unorm, const float max_unorm, const float param_norm,
-                const float beta1,
+                const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1,
                 float* max1, float* new_max1,
@@ -1307,8 +1325,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
       {
         g_val = float(g_vals[j]);
         g_val *= gnorm_scale;
-        if(weight_decay > 0.0f)
-          g_val += ((float)p_vals[j])*weight_decay;
+
+        if(weight_decay > 0.0f) {
+          switch(OPTIMIZER) {
+            case MOMENTUM:
+            case RMSPROP:
+              g_val += ((float)p_vals[j])*weight_decay;
+              break;
+            case LION:
+              p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
+              break;
+          }
+        }
+
         s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
 
         switch(OPTIMIZER)
@@ -1321,6 +1350,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
             p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
             break;
+          case LION:
+            p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
+            s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+            break;
           case RMSPROP:
             s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
             p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
@@ -1649,10 +1682,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
       {
         g_val = float(g_vals[j]);
         g_val *= gnorm_scale;
-        if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
-        {
-          if(weight_decay > 0.0f)
-            g_val += ((float)p_vals[j])*weight_decay;
+        if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
+        {
+          if(weight_decay > 0.0f) {
+            switch(OPTIMIZER) {
+              case MOMENTUM:
+              case ADAGRAD:
+              case RMSPROP:
+                g_val += ((float)p_vals[j])*weight_decay;
+                break;
+              case LION:
+                p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
+                break;
+            }
+          }
 
           s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
 
@@ -1664,6 +1707,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
           else
             s1_vals[j] = (s1_vals[j]*beta1) + g_val;
           break;
+        case LION:
+          // here we use g_vals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2
+          g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
+          s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+          break;
         case RMSPROP:
           s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
           break;
@@ -1701,6 +1749,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
           case MOMENTUM:
             p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
             break;
+          case LION:
+            p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
+            break;
           case RMSPROP:
             g_val = g_vals[j];
             p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
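Two details of the `LION` branches above are easy to miss: weight decay is decoupled (applied multiplicatively to `p_vals` instead of being folded into `g_val` as for the other 1-state optimizers), and the parameter update reads `state1` before the `beta2` momentum update overwrites it — in the blockwise kernel, `g_vals[j]` is repurposed to carry the scaled sign update between the two steps. A dense Python mirror of that branch, as a reference sketch only:

```python
import torch

def lion_blockwise_branch(p, g, s1, lr, beta1, beta2, weight_decay):
    """Dense mirror of the LION cases in kOptimizerStatic8bit1StateBlockwise."""
    if weight_decay > 0.0:
        # decoupled decay applied to the parameter itself, instead of
        # g += p * weight_decay as in the momentum/RMSProp/Adagrad branches
        p = p * (1.0 - lr * weight_decay)
    # like g_vals[j] in the kernel, reuse the gradient slot for the scaled
    # sign update, computed from the OLD momentum before beta2 advances it
    update = lr * torch.sign(beta1 * s1 + (1.0 - beta1) * g)
    s1 = beta2 * s1 + (1.0 - beta2) * g
    p = p - update
    return p, s1
```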
@@ -2692,24 +2743,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n);
 
 #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
 template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
                 float* state1, float *unorm, \
-                const float beta1, const float eps, const float weight_decay, \
+                const float beta1, const float beta2, const float eps, const float weight_decay, \
                 const int step, const float lr, const float gnorm_scale, const int n); \
 
 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+MAKE_PreconditionOptimizer32bit1State(LION, half)
+MAKE_PreconditionOptimizer32bit1State(LION, float)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
 
 #define MAKE_Optimizer32bit1State(oname, gtype) \
 template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
-                const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
+                const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
 
 MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
 MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
+MAKE_Optimizer32bit1State(LION, half)
+MAKE_Optimizer32bit1State(LION, float)
 MAKE_Optimizer32bit1State(ADAGRAD, half)
 MAKE_Optimizer32bit1State(ADAGRAD, float)
 
@@ -2731,6 +2786,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
 
 #define MAKE_PreconditionStatic8bit1State(oname, gtype) \
 template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
                 float *unorm, \
                 const float beta1, \
+                const float beta2, \
                 const float eps, const int step, \
                 float* __restrict__ const quantiles1, \
                 float* max1, float* new_max1, \
@@ -2742,11 +2798,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
 MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
 MAKE_PreconditionStatic8bit1State(RMSPROP, half)
 MAKE_PreconditionStatic8bit1State(RMSPROP, float)
+MAKE_PreconditionStatic8bit1State(LION, half)
+MAKE_PreconditionStatic8bit1State(LION, float)
 
 #define MAKE_optimizerStatic8bit1State(oname, gtype) \
 template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
                 const float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, \
+                const float beta2, \
                 const float eps, const int step, const float lr, \
                 float* __restrict__ const quantiles1, \
                 float* max1, float* new_max1, \
@@ -2758,6 +2817,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
 MAKE_optimizerStatic8bit1State(MOMENTUM, float)
 MAKE_optimizerStatic8bit1State(RMSPROP, half)
 MAKE_optimizerStatic8bit1State(RMSPROP, float)
+MAKE_optimizerStatic8bit1State(LION, half)
+MAKE_optimizerStatic8bit1State(LION, float)
 
 #define MAKE_PreconditionStatic8bit2State(oname, gtype) \
 template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
@@ -2849,5 +2910,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
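The 8-bit kernels never hold `state1` in full precision: each byte is an index into a shared quantile table (`smem_quantiles1`), rescaled by a running absolute maximum (`max1`, or per-block `absmax1` in the blockwise variant). A rough Python sketch of the dequantize-update-requantize round trip for Lion's momentum; `quantize_to_map` is a hypothetical stand-in for the kernel's `dQuantize`:

```python
import torch

def quantize_to_map(x, qmap):
    """Hypothetical helper: index of the nearest entry in the normalized
    quantile map (the kernel does this per value with dQuantize)."""
    return (x.unsqueeze(-1) - qmap).abs().argmin(dim=-1).to(torch.uint8)

def lion8bit_state_roundtrip(c1, qmap, max1, g, beta2):
    s1 = qmap[c1.long()] * max1                # dequantize: quantile * running absmax
    s1 = beta2 * s1 + (1.0 - beta2) * g        # the LION state update from above
    new_max1 = s1.abs().max()                  # tracked (new_max1) for rescaling
    c1 = quantize_to_map(s1 / new_max1, qmap)  # requantize back to 8-bit indices
    return c1, new_max1
```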
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index d90ea13..a8aa3fc 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
                 float* state1, float *unorm,
-                const float beta1, const float eps, const float weight_decay,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const int n);
 
 template<typename T, int OPTIMIZER>
 __global__ void kOptimizer32bit1State(T* g, T* p,
                 float* state1, float *unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float eps, const float weight_decay,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
 template<typename T, int OPTIMIZER>
 __global__ void
 kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
                 float *unorm,
-                const float beta1,
+                const float beta1, const float beta2,
                 const float eps, const int step,
                 float* __restrict__ const quantiles1,
                 float* max1, float* new_max1,
@@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
 __global__ void
 kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                 const float *unorm, const float max_unorm, const float param_norm,
-                const float beta1,
+                const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1,
                 float* max1, float* new_max1,
diff --git a/csrc/ops.cu b/csrc/ops.cu
index e770e10..94d5f2e 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -120,17 +120,28 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
     case MOMENTUM:
     case RMSPROP:
     case ADAGRAD:
-
       if(max_unorm > 0.0f)
       {
         CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
-        kPreconditionOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+        kPreconditionOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
 
-      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
+    case LION:
+      // in Lion, the momentum update happens after the parameter update
+      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+      if(max_unorm > 0.0f)
+      {
+        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
+        kPreconditionOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+        CUDA_CHECK_RETURN(cudaPeekAtLastError());
+      }
+      break;
   }
 }
@@ -164,12 +175,22 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     case RMSPROP:
     case ADAGRAD:
       CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
-      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
-      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
+      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
             quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
+    case LION:
+      // in Lion, the momentum update happens after the parameter update
+      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+            quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+      CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
+      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      CUDA_CHECK_RETURN(cudaPeekAtLastError());
+      break;
     default:
       break;
   }
@@ -198,6 +219,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
     case MOMENTUM:
     case RMSPROP:
     case ADAGRAD:
+    case LION:
      num_blocks = n/BLOCKSIZE_1STATE;
      num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
      kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
@@ -707,6 +729,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
 MAKE_optimizer32bit(MOMENTUM, float)
 MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
+MAKE_optimizer32bit(LION, half)
+MAKE_optimizer32bit(LION, float)
 MAKE_optimizer32bit(ADAGRAD, half)
 MAKE_optimizer32bit(ADAGRAD, float)
 
@@ -726,6 +750,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
 MAKE_optimizerStatic8bit(MOMENTUM, float)
 MAKE_optimizerStatic8bit(RMSPROP, half)
 MAKE_optimizerStatic8bit(RMSPROP, float)
+MAKE_optimizerStatic8bit(LION, half)
+MAKE_optimizerStatic8bit(LION, float)
 
 #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
 template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
@@ -738,6 +764,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
+MAKE_optimizerStatic8bitBlockwise(half, LION);
+MAKE_optimizerStatic8bitBlockwise(float, LION);
 MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 31d4dd8..9f06435 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -70,6 +70,7 @@ typedef enum Optimizer_t
     RMSPROP = 2,
     LARS = 3,
     ADAGRAD = 4,
+    LION = 5,
 } Optimizer_t;
 
 typedef enum Transform_t
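The launch-order changes in `ops.cu` follow from the comment: because Lion advances its momentum only after the parameter update, the precondition kernels — which derive `unorm` (for percentile clipping) and `new_max1` (for requantization) from `state1` — are launched after `kOptimizer*1State` rather than before it, presumably so those statistics reflect the freshly written state. Schematically (stub callables, not a real API):

```python
def step_rmsprop_like(launch):
    launch("precondition")  # stats computed from state1 before it changes
    launch("optimizer")

def step_lion(launch):
    launch("optimizer")     # updates p, then advances state1 with beta2
    launch("precondition")  # stats recomputed from the freshly written state1

step_lion(lambda name: print(f"launching {name} kernel"))
```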
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index d8b2290..4caa7e8 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
 MAKE_FUNC32(adam, ADAM, half, 16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
+MAKE_FUNC32(lion, LION, float, 32)
+MAKE_FUNC32(lion, LION, half, 16)
 MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
 MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
 
@@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
 MAKE_FUNC8(momentum, MOMENTUM, half, 16)
 MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
+MAKE_FUNC8(lion, LION, float, 32)
+MAKE_FUNC8(lion, LION, half, 16)
 
 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
 void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
@@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
 MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
+MAKE_BLOCKWISE8(lion, LION, half, 16)
+MAKE_BLOCKWISE8(lion, LION, float, 32)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
 
@@ -161,6 +167,8 @@ extern "C"
     MAKE_CFUNC32(momentum, half, 16)
     MAKE_CFUNC32(rmsprop, float, 32)
     MAKE_CFUNC32(rmsprop, half, 16)
+    MAKE_CFUNC32(lion, float, 32)
+    MAKE_CFUNC32(lion, half, 16)
     MAKE_CFUNC32(adagrad, float, 32)
     MAKE_CFUNC32(adagrad, half, 16)
 
@@ -183,6 +191,8 @@ extern "C"
     MAKE_CFUNC8(momentum, half, 16)
     MAKE_CFUNC8(rmsprop, float, 32)
     MAKE_CFUNC8(rmsprop, half, 16)
+    MAKE_CFUNC8(lion, float, 32)
+    MAKE_CFUNC8(lion, half, 16)
 
 #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
 void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
@@ -196,6 +206,8 @@ extern "C"
     MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
     MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
     MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
+    MAKE_CBLOCKWISE8(lion, LION, half, 16)
+    MAKE_CBLOCKWISE8(lion, LION, float, 32)
     MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
     MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
diff --git a/requirements.txt b/requirements.txt
index e079f8a..883b2e4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
+lion-pytorch
 pytest
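`lion-pytorch` (lucidrains' reference implementation) is pulled in purely as a test oracle: the parametrized tests below run it side by side with `bnb.optim.Lion` on copies of the same parameters and compare `exp_avg` against `state1`. Roughly, condensed from what the existing `test_optimizer32bit` harness does:

```python
import torch
from lion_pytorch import Lion
import bitsandbytes as bnb

p1 = torch.randn(1024, 32, device="cuda") * 0.1
p2 = p1.clone()
ref, fast = Lion([p1], lr=1e-4), bnb.optim.Lion([p2], lr=1e-4)

for _ in range(10):
    g = torch.randn_like(p1) * 0.01
    p1.grad, p2.grad = g, g.clone()
    ref.step()
    fast.step()
    # the single Lion state should track the reference momentum closely
    torch.testing.assert_close(
        ref.state[p1]["exp_avg"], fast.state[p2]["state1"],
        rtol=1e-4, atol=1e-4,
    )
```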
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 3df2dad..9f815ab 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -7,6 +7,8 @@ from itertools import product
 from os.path import join
 
 import pytest
+from lion_pytorch import Lion
+
 import torch
 
 import bitsandbytes as bnb
@@ -31,6 +33,7 @@ str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
 # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
     None,
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
 )
 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers["lion"] = (Lion, bnb.optim.Lion)
 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
     torch.optim.Adam,
     lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
 )
+str2optimizers["lion8bit"] = (
+    Lion,
+    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
+)
 str2optimizers["momentum8bit"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
     torch.optim.Adam,
     lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
 )
+str2optimizers["lion8bit_blockwise"] = (
+    Lion,
+    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
+)
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
 
 str2statenames = {}
 str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lars"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
@@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
     ("exp_avg", "state1", "qmap1", "max1"),
     ("exp_avg_sq", "state2", "qmap2", "max2"),
 ]
+str2statenames["lion8bit"] = [
+    ("exp_avg", "state1", "qmap1", "max1")
+]
 str2statenames["lamb8bit"] = [
     ("exp_avg", "state1", "qmap1", "max1"),
     ("exp_avg_sq", "state2", "qmap2", "max2"),
@@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
     ("exp_avg", "state1", "qmap1", "absmax1"),
     ("exp_avg_sq", "state2", "qmap2", "absmax2"),
 ]
+str2statenames["lion8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1")
+]
 str2statenames["momentum8bit"] = [
     ("momentum_buffer", "state1", "qmap1", "max1")
 ]
@@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
+optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
@@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
 optimizer_names = [
     "adam8bit",
+    "lion8bit",
     "momentum8bit",
     "rmsprop8bit",
     "adam8bit_blockwise",
+    "lion8bit_blockwise",
     "lars8bit",
     "momentum8bit_blockwise",
     "rmsprop8bit_blockwise",
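With the new `str2optimizers`/`str2statenames` entries and the extra `optimizer_names`, the existing 32-bit and 8-bit suites cover Lion without any new test functions. Note `min_8bit_size=4096` in the constructors: as with the other 8-bit optimizers, tensors with fewer elements keep 32-bit state, so a quick manual smoke test of the quantized path should use parameters at least that large:

```python
import torch
import bitsandbytes as bnb

layer = torch.nn.Linear(4096, 4096).cuda()   # large enough for 8-bit state
opt = bnb.optim.Lion8bit(layer.parameters())  # block_wise=True by default

for _ in range(5):
    loss = layer(torch.randn(8, 4096, device="cuda")).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    assert torch.isfinite(loss), "Lion8bit produced a non-finite loss"
```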