From c83888aa1aab50fde54ccad19114e781e2fc62a4 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 11:54:54 -0800 Subject: [PATCH] use epsilon as beta2 for lion, complete most of the logic in kernel.cu for all functions --- bitsandbytes/optim/lion.py | 17 ++++++++------- csrc/kernels.cu | 42 ++++++++++++++++++++++++++++---------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 4a00f57..81a9efe 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -18,12 +18,13 @@ class Lion(Optimizer1State): percentile_clipping=100, block_wise=True, ): + beta1, beta2 = betas super().__init__( "lion", params, lr, - betas, - 0., + (beta1, 0.), + beta2, weight_decay, optim_bits, args, @@ -44,13 +45,14 @@ class Lion8bit(Optimizer1State): min_8bit_size=4096, percentile_clipping=100, block_wise=True, - ): + ): + beta1, beta2 = betas super().__init__( "lion", params, lr, - betas, - 0., + (beta1, 0.), + beta2, weight_decay, 8, args, @@ -72,12 +74,13 @@ class Lion32bit(Optimizer1State): percentile_clipping=100, block_wise=True, ): + beta1, beta2 = betas super().__init__( "lion", params, lr, - betas, - 0., + (beta1, 0.), + beta2, weight_decay, 32, args, diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 76a8c73..553f884 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) { return __int_as_float(old); } +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template <typename T> +__device__ int sgn(T val) { + return (T(0) < val) - (val < T(0)); +} + template <int STOCHASTIC> __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) { @@ -217,14 +225,6 @@ __device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float * } } -// sign function for lion -// taken from https://stackoverflow.com/a/4609795, but not sure if there's 
a proper way to do this in CUDA - -template <typename T> -__device__ int sgn(T val) { - return (T(0) < val) - (val < T(0)); -} - __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) { const int tid = threadIdx.x + (blockDim.x*blockIdx.x); @@ -799,6 +799,10 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; case LION: + // using eps as beta2 + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*(float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value @@ -899,7 +903,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p, p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); break; - case LION: + case LION: + // using eps as beta2 + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*((float)g_vals[j])); + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); @@ -1230,6 +1238,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c local_unorm += s1_vals[j]*s1_vals[j]; break; case LION: + // using eps as beta2 + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; @@ -1333,6 +1344,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); break; case LION: + // using eps as beta2 + p_vals[j] = ((float)p_vals[j]) - 
(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); @@ -1676,7 +1691,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char else s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; - case LION: + case LION: + // using eps as beta2 + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; @@ -1714,7 +1732,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char case MOMENTUM: p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; - case LION: + case LION: + p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))); + break; case RMSPROP: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));