do the epsilon beta2 switcharoo within the cuda code, and not within the python class (so that the state dict still makes sense)

This commit is contained in:
Phil Wang 2023-03-10 08:57:59 -08:00
parent 8618bed001
commit c99b44f774
2 changed files with 10 additions and 13 deletions

View File

@ -18,13 +18,12 @@ class Lion(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
beta1, beta2 = betas
super().__init__(
"lion",
params,
lr,
(beta1, 0.),
beta2,
(beta1, beta2),
0.,
weight_decay,
optim_bits,
args,
@ -46,13 +45,12 @@ class Lion8bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
beta1, beta2 = betas
super().__init__(
"lion",
params,
lr,
(beta1, 0.),
beta2,
(beta1, beta2),
0.,
weight_decay,
8,
args,
@ -74,13 +72,12 @@ class Lion32bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
beta1, beta2 = betas
super().__init__(
"lion",
params,
lr,
(beta1, 0.),
beta2,
betas,
0.,
weight_decay,
32,
args,

View File

@ -132,13 +132,13 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
break;
case LION:
// in lion, the momentum update after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
break;
@ -183,12 +183,12 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
default: