do the epsilon beta2 switcharoo within the cuda code, and not within the python class (so that the state dict still makes sense)
This commit is contained in:
parent
8618bed001
commit
c99b44f774
|
@ -18,13 +18,12 @@ class Lion(Optimizer1State):
|
|||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
beta1, beta2 = betas
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
(beta1, 0.),
|
||||
beta2,
|
||||
(beta1, beta2),
|
||||
0.,
|
||||
weight_decay,
|
||||
optim_bits,
|
||||
args,
|
||||
|
@ -46,13 +45,12 @@ class Lion8bit(Optimizer1State):
|
|||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
beta1, beta2 = betas
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
(beta1, 0.),
|
||||
beta2,
|
||||
(beta1, beta2),
|
||||
0.,
|
||||
weight_decay,
|
||||
8,
|
||||
args,
|
||||
|
@ -74,13 +72,12 @@ class Lion32bit(Optimizer1State):
|
|||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
beta1, beta2 = betas
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
(beta1, 0.),
|
||||
beta2,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
32,
|
||||
args,
|
||||
|
|
|
@ -132,13 +132,13 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
break;
|
||||
case LION:
|
||||
// in lion, the momentum update after the parameter update
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
break;
|
||||
|
@ -183,12 +183,12 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
break;
|
||||
case LION:
|
||||
// in lion, the momentum update happens after the parameter update
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
default:
|
||||
|
|
Loading…
Reference in New Issue
Block a user