follow advice of Tim to fix update of momentum vs parameters in blockwise 8 bit
This commit is contained in:
parent
369a51c432
commit
9b656f461a
|
@ -1708,6 +1708,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
||||||
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
|
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
|
||||||
break;
|
break;
|
||||||
case LION:
|
case LION:
|
||||||
|
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])));
|
||||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||||
break;
|
break;
|
||||||
case RMSPROP:
|
case RMSPROP:
|
||||||
|
@ -1748,7 +1749,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
||||||
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
|
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
|
||||||
break;
|
break;
|
||||||
case LION:
|
case LION:
|
||||||
p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])));
|
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
|
||||||
break;
|
break;
|
||||||
case RMSPROP:
|
case RMSPROP:
|
||||||
g_val = g_vals[j];
|
g_val = g_vals[j];
|
||||||
|
|
Loading…
Reference in New Issue
Block a user