diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 98a3188..3c3445f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1708,6 +1708,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; case LION: + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: @@ -1748,7 +1749,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; case LION: - p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); break; case RMSPROP: g_val = g_vals[j];