use epsilon as beta2 for lion, complete most of the logic in kernel.cu for all functions

This commit is contained in:
Phil Wang 2023-03-09 11:54:54 -08:00
parent 64bb1ae8d1
commit c83888aa1a
2 changed files with 41 additions and 18 deletions

View File

@ -18,12 +18,13 @@ class Lion(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
beta1, beta2 = betas
super().__init__(
"lion",
params,
lr,
betas,
0.,
(beta1, 0.),
beta2,
weight_decay,
optim_bits,
args,
@ -44,13 +45,14 @@ class Lion8bit(Optimizer1State):
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
):
beta1, beta2 = betas
super().__init__(
"lion",
params,
lr,
betas,
0.,
(beta1, 0.),
beta2,
weight_decay,
8,
args,
@ -72,12 +74,13 @@ class Lion32bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
beta1, beta2 = betas
super().__init__(
"lion",
params,
lr,
betas,
0.,
(beta1, 0.),
beta2,
weight_decay,
32,
args,

View File

@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
return __int_as_float(old);
}
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template <typename T>
__device__ int sgn(T val) {
return (T(0) < val) - (val < T(0));
}
template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
@ -217,14 +225,6 @@ __device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *
}
}
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template <typename T>
__device__ int sgn(T val) {
return (T(0) < val) - (val < T(0));
}
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
{
const int tid = threadIdx.x + (blockDim.x*blockIdx.x);
@ -799,6 +799,10 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case LION:
// using eps as beta2
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*(float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
@ -899,7 +903,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
break;
case LION:
case LION:
// using eps as beta2
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*((float)g_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
@ -1230,6 +1238,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
local_unorm += s1_vals[j]*s1_vals[j];
break;
case LION:
// using eps as beta2
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
@ -1333,6 +1344,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
break;
case LION:
// using eps as beta2
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
@ -1676,7 +1691,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
else
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break;
case LION:
case LION:
// using eps as beta2
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
@ -1714,7 +1732,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break;
case LION:
case LION:
p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])));
break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));