fix weight decay for lion to be decoupled, using a switch
This commit is contained in:
parent
ead570a43e
commit
af03430992
|
@ -1328,8 +1328,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||||
{
|
{
|
||||||
g_val = float(g_vals[j]);
|
g_val = float(g_vals[j]);
|
||||||
g_val *= gnorm_scale;
|
g_val *= gnorm_scale;
|
||||||
if(weight_decay > 0.0f)
|
|
||||||
g_val += ((float)p_vals[j])*weight_decay;
|
if(weight_decay > 0.0f) {
|
||||||
|
switch(OPTIMIZER) {
|
||||||
|
case MOMENTUM:
|
||||||
|
case RMSPROP:
|
||||||
|
g_val += ((float)p_vals[j])*weight_decay;
|
||||||
|
break;
|
||||||
|
case LION:
|
||||||
|
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
|
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
|
||||||
|
|
||||||
switch(OPTIMIZER)
|
switch(OPTIMIZER)
|
||||||
|
@ -1677,8 +1688,17 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
||||||
g_val *= gnorm_scale;
|
g_val *= gnorm_scale;
|
||||||
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
|
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
|
||||||
{
|
{
|
||||||
if(weight_decay > 0.0f)
|
if(weight_decay > 0.0f) {
|
||||||
g_val += ((float)p_vals[j])*weight_decay;
|
switch(OPTIMIZER) {
|
||||||
|
case MOMENTUM:
|
||||||
|
case RMSPROP:
|
||||||
|
g_val += ((float)p_vals[j])*weight_decay;
|
||||||
|
break;
|
||||||
|
case LION:
|
||||||
|
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
|
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user