Do the usual bookkeeping to wire up the Lion optimizer before getting to the main Lion logic
This commit is contained in:
parent
d43ea9722c
commit
cb4c3c8c66
|
@ -40,7 +40,7 @@ out = linear(x.to(torch.float16))
|
|||
## Features
|
||||
- 8-bit Matrix multiplication with mixed precision decomposition
|
||||
- LLM.int8() inference
|
||||
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
|
||||
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
|
||||
- Stable Embedding Layer: Improved stability through better initialization and normalization
|
||||
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
|
||||
- Fast quantile estimation: Up to 100x faster than other algorithms
|
||||
|
|
|
@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA:
|
|||
lib.crmsprop32bit_g32,
|
||||
lib.crmsprop32bit_g16,
|
||||
)
|
||||
str2optimizer32bit["lion"] = (
|
||||
lib.clion32bit_g32,
|
||||
lib.clion32bit_g16,
|
||||
)
|
||||
str2optimizer32bit["adagrad"] = (
|
||||
lib.cadagrad32bit_g32,
|
||||
lib.cadagrad32bit_g16,
|
||||
|
@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA:
|
|||
lib.crmsprop_static_8bit_g32,
|
||||
lib.crmsprop_static_8bit_g16,
|
||||
)
|
||||
str2optimizer8bit["lion"] = (
|
||||
lib.clion_static_8bit_g32,
|
||||
lib.clion_static_8bit_g16,
|
||||
)
|
||||
str2optimizer8bit["lamb"] = (
|
||||
lib.cadam_static_8bit_g32,
|
||||
lib.cadam_static_8bit_g16,
|
||||
|
@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA:
|
|||
lib.crmsprop_8bit_blockwise_fp32,
|
||||
lib.crmsprop_8bit_blockwise_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["lion"] = (
|
||||
lib.clion_8bit_blockwise_fp32,
|
||||
lib.clion_8bit_blockwise_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["adagrad"] = (
|
||||
lib.cadagrad_8bit_blockwise_fp32,
|
||||
lib.cadagrad_8bit_blockwise_fp16,
|
||||
|
|
|
@ -19,7 +19,7 @@ class Lion(Optimizer1State):
|
|||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"rmsprop",
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
|
@ -46,7 +46,7 @@ class Lion8bit(Optimizer1State):
|
|||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"rmsprop",
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
|
@ -73,7 +73,7 @@ class Lion32bit(Optimizer1State):
|
|||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"rmsprop",
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
|
|
|
@ -790,6 +790,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
|||
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
|
||||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
break;
|
||||
case LION:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
|
||||
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
|
||||
|
@ -890,6 +891,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
|
|||
|
||||
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
|
||||
|
@ -1219,6 +1221,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
if(unorm != NULL)
|
||||
local_unorm += s1_vals[j]*s1_vals[j];
|
||||
break;
|
||||
case LION:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
|
@ -1321,6 +1324,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
|
||||
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
|
||||
|
@ -1664,6 +1668,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
else
|
||||
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
|
||||
break;
|
||||
case LION:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
|
@ -1701,6 +1706,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
case MOMENTUM:
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
|
||||
break;
|
||||
case LION:
|
||||
case RMSPROP:
|
||||
g_val = g_vals[j];
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
|
||||
|
@ -2699,6 +2705,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
|
|||
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
|
||||
|
||||
|
@ -2710,6 +2718,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half)
|
|||
MAKE_Optimizer32bit1State(MOMENTUM, float)
|
||||
MAKE_Optimizer32bit1State(RMSPROP, half)
|
||||
MAKE_Optimizer32bit1State(RMSPROP, float)
|
||||
MAKE_Optimizer32bit1State(LION, half)
|
||||
MAKE_Optimizer32bit1State(LION, float)
|
||||
MAKE_Optimizer32bit1State(ADAGRAD, half)
|
||||
MAKE_Optimizer32bit1State(ADAGRAD, float)
|
||||
|
||||
|
@ -2742,6 +2752,8 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
|
|||
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
|
||||
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
|
||||
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
|
||||
MAKE_PreconditionStatic8bit1State(LION, half)
|
||||
MAKE_PreconditionStatic8bit1State(LION, float)
|
||||
|
||||
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
|
||||
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
|
||||
|
@ -2758,6 +2770,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
|
|||
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
|
||||
MAKE_optimizerStatic8bit1State(RMSPROP, half)
|
||||
MAKE_optimizerStatic8bit1State(RMSPROP, float)
|
||||
MAKE_optimizerStatic8bit1State(LION, half)
|
||||
MAKE_optimizerStatic8bit1State(LION, float)
|
||||
|
||||
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
|
||||
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
|
||||
|
@ -2849,5 +2863,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
|
|||
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
|
||||
|
|
|
@ -120,6 +120,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
case LION:
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
|
@ -163,6 +164,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
case LION:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
@ -198,6 +200,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
case LION:
|
||||
num_blocks = n/BLOCKSIZE_1STATE;
|
||||
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
|
||||
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||
|
@ -707,6 +710,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
|
|||
MAKE_optimizer32bit(MOMENTUM, float)
|
||||
MAKE_optimizer32bit(RMSPROP, half)
|
||||
MAKE_optimizer32bit(RMSPROP, float)
|
||||
MAKE_optimizer32bit(LION, half)
|
||||
MAKE_optimizer32bit(LION, float)
|
||||
MAKE_optimizer32bit(ADAGRAD, half)
|
||||
MAKE_optimizer32bit(ADAGRAD, float)
|
||||
|
||||
|
@ -726,6 +731,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
|
|||
MAKE_optimizerStatic8bit(MOMENTUM, float)
|
||||
MAKE_optimizerStatic8bit(RMSPROP, half)
|
||||
MAKE_optimizerStatic8bit(RMSPROP, float)
|
||||
MAKE_optimizerStatic8bit(LION, half)
|
||||
MAKE_optimizerStatic8bit(LION, float)
|
||||
|
||||
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
|
||||
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
|
||||
|
@ -738,6 +745,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
|
|||
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
|
||||
|
||||
|
|
|
@ -70,6 +70,7 @@ typedef enum Optimizer_t
|
|||
RMSPROP = 2,
|
||||
LARS = 3,
|
||||
ADAGRAD = 4,
|
||||
LION = 5,
|
||||
} Optimizer_t;
|
||||
|
||||
typedef enum Transform_t
|
||||
|
|
|
@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
|
|||
MAKE_FUNC32(adam, ADAM, half, 16)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC32(lion, LION, float, 32)
|
||||
MAKE_FUNC32(lion, LION, half, 16)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
|
||||
|
||||
|
@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
|
|||
MAKE_FUNC8(momentum, MOMENTUM, half, 16)
|
||||
MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC8(lion, LION, float, 32)
|
||||
MAKE_FUNC8(lion, LION, half, 16)
|
||||
|
||||
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||
|
@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
|
|||
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_BLOCKWISE8(lion, LION, half, 16)
|
||||
MAKE_BLOCKWISE8(lion, LION, float, 32)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
||||
|
||||
|
@ -161,6 +167,8 @@ extern "C"
|
|||
MAKE_CFUNC32(momentum, half, 16)
|
||||
MAKE_CFUNC32(rmsprop, float, 32)
|
||||
MAKE_CFUNC32(rmsprop, half, 16)
|
||||
MAKE_CFUNC32(lion, float, 32)
|
||||
MAKE_CFUNC32(lion, half, 16)
|
||||
MAKE_CFUNC32(adagrad, float, 32)
|
||||
MAKE_CFUNC32(adagrad, half, 16)
|
||||
|
||||
|
@ -183,6 +191,8 @@ extern "C"
|
|||
MAKE_CFUNC8(momentum, half, 16)
|
||||
MAKE_CFUNC8(rmsprop, float, 32)
|
||||
MAKE_CFUNC8(rmsprop, half, 16)
|
||||
MAKE_CFUNC8(lion, float, 32)
|
||||
MAKE_CFUNC8(lion, half, 16)
|
||||
|
||||
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||
|
@ -196,6 +206,8 @@ extern "C"
|
|||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_CBLOCKWISE8(lion, LION, half, 16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, float, 32)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
||||
|
||||
|
|
|
@ -50,6 +50,10 @@ str2optimizers["rmsprop"] = (
|
|||
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
|
||||
)
|
||||
str2optimizers["rmsprop"] = (
|
||||
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
|
||||
)
|
||||
str2optimizers["adam8bit"] = (
|
||||
torch.optim.Adam,
|
||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
|
||||
|
|
Loading…
Reference in New Issue
Block a user