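Fix weight decay for Lion to be decoupled: using a switch on the optimizer type, Lion now scales the parameter by (1 - lr * weight_decay) instead of adding p * weight_decay to the gradient, which remains the behavior for Momentum and RMSprop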
This commit is contained in:
parent
ead570a43e
commit
af03430992
|
@ -1328,8 +1328,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
{
|
||||
g_val = float(g_vals[j]);
|
||||
g_val *= gnorm_scale;
|
||||
if(weight_decay > 0.0f)
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
|
||||
if(weight_decay > 0.0f) {
|
||||
switch(OPTIMIZER) {
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
|
||||
|
||||
switch(OPTIMIZER)
|
||||
|
@ -1677,8 +1688,17 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
g_val *= gnorm_scale;
|
||||
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
|
||||
{
|
||||
if(weight_decay > 0.0f)
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
if(weight_decay > 0.0f) {
|
||||
switch(OPTIMIZER) {
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user