diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 81a9efe..c267af7 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -18,13 +18,12 @@ class Lion(Optimizer1State): percentile_clipping=100, block_wise=True, ): - beta1, beta2 = betas super().__init__( "lion", params, lr, - (beta1, 0.), - beta2, + (beta1, beta2), + 0., weight_decay, optim_bits, args, @@ -46,13 +45,12 @@ class Lion8bit(Optimizer1State): percentile_clipping=100, block_wise=True, ): - beta1, beta2 = betas super().__init__( "lion", params, lr, - (beta1, 0.), - beta2, + (beta1, beta2), + 0., weight_decay, 8, args, @@ -74,13 +72,12 @@ class Lion32bit(Optimizer1State): percentile_clipping=100, block_wise=True, ): - beta1, beta2 = betas super().__init__( "lion", params, lr, - (beta1, 0.), - beta2, + betas, + 0., weight_decay, 32, args, diff --git a/csrc/ops.cu b/csrc/ops.cu index 384aff7..51c530e 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -132,13 +132,13 @@ template void optimizer32bit(T* g, T* p, break; case LION: // in lion, the momentum update after the parameter update - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } break; @@ -183,12 +183,12 @@ template void optimizerStatic8bit(T* p, T* g, break; case LION: // in lion, the momentum update happens after the parameter update - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; default: