switch all eps to beta2
This commit is contained in:
parent
6c377b39b6
commit
369a51c432
|
@ -799,8 +799,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
|||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*(float)g_vals[j]); // state update
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
|
||||
|
@ -903,9 +902,8 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
|
|||
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
|
||||
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*((float)g_vals[j]));
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
|
||||
|
@ -1237,7 +1235,6 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
local_unorm += s1_vals[j]*s1_vals[j];
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
|
@ -1354,7 +1351,6 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
// using eps as beta2
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue
Block a user