use epsilon as beta2 for lion, complete most of the logic in kernel.cu for all functions
This commit is contained in:
parent
64bb1ae8d1
commit
c83888aa1a
|
@ -18,12 +18,13 @@ class Lion(Optimizer1State):
|
|||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
beta1, beta2 = betas
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
(beta1, 0.),
|
||||
beta2,
|
||||
weight_decay,
|
||||
optim_bits,
|
||||
args,
|
||||
|
@ -45,12 +46,13 @@ class Lion8bit(Optimizer1State):
|
|||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
beta1, beta2 = betas
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
(beta1, 0.),
|
||||
beta2,
|
||||
weight_decay,
|
||||
8,
|
||||
args,
|
||||
|
@ -72,12 +74,13 @@ class Lion32bit(Optimizer1State):
|
|||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
beta1, beta2 = betas
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
(beta1, 0.),
|
||||
beta2,
|
||||
weight_decay,
|
||||
32,
|
||||
args,
|
||||
|
|
|
@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
|
|||
return __int_as_float(old);
|
||||
}
|
||||
|
||||
// sign function for lion
|
||||
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
|
||||
|
||||
template <typename T>
|
||||
__device__ int sgn(T val) {
|
||||
return (T(0) < val) - (val < T(0));
|
||||
}
|
||||
|
||||
template <int STOCHASTIC>
|
||||
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
|
||||
{
|
||||
|
@ -217,14 +225,6 @@ __device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *
|
|||
}
|
||||
}
|
||||
|
||||
// sign function for lion
|
||||
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
|
||||
|
||||
template <typename T>
|
||||
__device__ int sgn(T val) {
|
||||
return (T(0) < val) - (val < T(0));
|
||||
}
|
||||
|
||||
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
|
||||
{
|
||||
const int tid = threadIdx.x + (blockDim.x*blockIdx.x);
|
||||
|
@ -799,6 +799,10 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
|||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*(float)g_vals[j]); // state update
|
||||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
|
||||
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
|
||||
|
@ -900,6 +904,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
|
|||
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
|
||||
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*((float)g_vals[j]));
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
|
||||
|
@ -1230,6 +1238,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
local_unorm += s1_vals[j]*s1_vals[j];
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
|
@ -1333,6 +1344,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
|
||||
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
|
||||
|
@ -1677,6 +1692,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
|
@ -1715,6 +1733,8 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])));
|
||||
break;
|
||||
case RMSPROP:
|
||||
g_val = g_vals[j];
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
|
||||
|
|
Loading…
Reference in New Issue
Block a user