do the epsilon beta2 switcharoo within the cuda code, and not within the python class (so that the state dict still makes sense)
This commit is contained in:
parent
8618bed001
commit
c99b44f774
|
@ -18,13 +18,12 @@ class Lion(Optimizer1State):
|
||||||
percentile_clipping=100,
|
percentile_clipping=100,
|
||||||
block_wise=True,
|
block_wise=True,
|
||||||
):
|
):
|
||||||
beta1, beta2 = betas
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
"lion",
|
"lion",
|
||||||
params,
|
params,
|
||||||
lr,
|
lr,
|
||||||
(beta1, 0.),
|
(beta1, beta2),
|
||||||
beta2,
|
0.,
|
||||||
weight_decay,
|
weight_decay,
|
||||||
optim_bits,
|
optim_bits,
|
||||||
args,
|
args,
|
||||||
|
@ -46,13 +45,12 @@ class Lion8bit(Optimizer1State):
|
||||||
percentile_clipping=100,
|
percentile_clipping=100,
|
||||||
block_wise=True,
|
block_wise=True,
|
||||||
):
|
):
|
||||||
beta1, beta2 = betas
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
"lion",
|
"lion",
|
||||||
params,
|
params,
|
||||||
lr,
|
lr,
|
||||||
(beta1, 0.),
|
(beta1, beta2),
|
||||||
beta2,
|
0.,
|
||||||
weight_decay,
|
weight_decay,
|
||||||
8,
|
8,
|
||||||
args,
|
args,
|
||||||
|
@ -74,13 +72,12 @@ class Lion32bit(Optimizer1State):
|
||||||
percentile_clipping=100,
|
percentile_clipping=100,
|
||||||
block_wise=True,
|
block_wise=True,
|
||||||
):
|
):
|
||||||
beta1, beta2 = betas
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
"lion",
|
"lion",
|
||||||
params,
|
params,
|
||||||
lr,
|
lr,
|
||||||
(beta1, 0.),
|
betas,
|
||||||
beta2,
|
0.,
|
||||||
weight_decay,
|
weight_decay,
|
||||||
32,
|
32,
|
||||||
args,
|
args,
|
||||||
|
|
|
@ -132,13 +132,13 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||||
break;
|
break;
|
||||||
case LION:
|
case LION:
|
||||||
// in lion, the momentum update after the parameter update
|
// in lion, the momentum update after the parameter update
|
||||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
|
||||||
if(max_unorm > 0.0f)
|
if(max_unorm > 0.0f)
|
||||||
{
|
{
|
||||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, weight_decay, step, lr, gnorm_scale, n);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
@ -183,12 +183,12 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
||||||
break;
|
break;
|
||||||
case LION:
|
case LION:
|
||||||
// in lion, the momentum update happens after the parameter update
|
// in lion, the momentum update happens after the parameter update
|
||||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, step, lr,
|
||||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
|
||||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user