always pass beta2 into all the 1state functions

This commit is contained in:
Phil Wang 2023-03-10 13:00:59 -08:00
parent abbe65adfc
commit 6c377b39b6
3 changed files with 22 additions and 20 deletions

View File

@ -751,7 +751,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n) const int step, const float lr, const float gnorm_scale, const int n)
{ {
@ -833,7 +833,7 @@ template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1) __launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p, __global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm, float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{ {
@ -1175,7 +1175,7 @@ __global__ void
__launch_bounds__(NUM_THREADS, 2) __launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm, float *unorm,
const float beta1, const float beta1, const float beta2,
const float eps, const int step, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles1,
float* max1, float* new_max1, float* max1, float* new_max1,
@ -1238,7 +1238,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
break; break;
case LION: case LION:
// using eps as beta2 // using eps as beta2
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break; break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
@ -1265,7 +1265,7 @@ template<typename T, int OPTIMIZER>
__global__ void __global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm, const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta1, const float beta2,
const float eps, const int step, const float lr, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles1,
float* max1, float* new_max1, float* max1, float* new_max1,
@ -1356,7 +1356,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
case LION: case LION:
// using eps as beta2 // using eps as beta2
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break; break;
case RMSPROP: case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
@ -2745,7 +2745,7 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \ template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \ float* state1, float *unorm, \
const float beta1, const float eps, const float weight_decay, \ const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \ const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
@ -2759,7 +2759,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
#define MAKE_Optimizer32bit1State(oname, gtype) \ #define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float) MAKE_Optimizer32bit1State(MOMENTUM, float)
@ -2788,6 +2788,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
float *unorm, \ float *unorm, \
const float beta1, \ const float beta1, \
const float beta2, \
const float eps, const int step, \ const float eps, const int step, \
float* __restrict__ const quantiles1, \ float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \ float* max1, float* new_max1, \
@ -2806,6 +2807,7 @@ MAKE_PreconditionStatic8bit1State(LION, float)
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \ template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \ const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \ const float beta1, \
const float beta2, \
const float eps, const int step, const float lr, \ const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \ float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \ float* max1, float* new_max1, \

View File

@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS> template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n); const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER> template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p, __global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm, float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER> template<typename T, int OPTIMIZER>
__global__ void __global__ void
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm, float *unorm,
const float beta1, const float beta1, const float beta2,
const float eps, const int step, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles1,
float* max1, float* new_max1, float* max1, float* new_max1,
@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
__global__ void __global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm, const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta1, const float beta2,
const float eps, const int step, const float lr, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles1,
float* max1, float* new_max1, float* max1, float* new_max1,

View File

@ -123,22 +123,22 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
if(max_unorm > 0.0f) if(max_unorm > 0.0f)
{ {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case LION: case LION:
// in lion, the momentum update after the parameter update // in lion, the momentum update after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
if(max_unorm > 0.0f) if(max_unorm > 0.0f)
{ {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, weight_decay, step, lr, gnorm_scale, n); kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
break; break;
@ -175,20 +175,20 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case LION: case LION:
// in lion, the momentum update happens after the parameter update // in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, step, lr, kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
default: default: