commit
b0ec20c3b3
|
@ -40,7 +40,7 @@ out = linear(x.to(torch.float16))
|
|||
## Features
|
||||
- 8-bit Matrix multiplication with mixed precision decomposition
|
||||
- LLM.int8() inference
|
||||
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
|
||||
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
|
||||
- Stable Embedding Layer: Improved stability through better initialization and normalization
|
||||
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
|
||||
- Fast quantile estimation: Up to 100x faster than other algorithms
|
||||
|
|
|
@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA:
|
|||
lib.crmsprop32bit_g32,
|
||||
lib.crmsprop32bit_g16,
|
||||
)
|
||||
str2optimizer32bit["lion"] = (
|
||||
lib.clion32bit_g32,
|
||||
lib.clion32bit_g16,
|
||||
)
|
||||
str2optimizer32bit["adagrad"] = (
|
||||
lib.cadagrad32bit_g32,
|
||||
lib.cadagrad32bit_g16,
|
||||
|
@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA:
|
|||
lib.crmsprop_static_8bit_g32,
|
||||
lib.crmsprop_static_8bit_g16,
|
||||
)
|
||||
str2optimizer8bit["lion"] = (
|
||||
lib.clion_static_8bit_g32,
|
||||
lib.clion_static_8bit_g16,
|
||||
)
|
||||
str2optimizer8bit["lamb"] = (
|
||||
lib.cadam_static_8bit_g32,
|
||||
lib.cadam_static_8bit_g16,
|
||||
|
@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA:
|
|||
lib.crmsprop_8bit_blockwise_fp32,
|
||||
lib.crmsprop_8bit_blockwise_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["lion"] = (
|
||||
lib.clion_8bit_blockwise_fp32,
|
||||
lib.clion_8bit_blockwise_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["adagrad"] = (
|
||||
lib.cadagrad_8bit_blockwise_fp32,
|
||||
lib.cadagrad_8bit_blockwise_fp16,
|
||||
|
|
|
@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
|
|||
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||
from .optimizer import GlobalOptimManager
|
||||
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||
from .lion import Lion, Lion8bit, Lion32bit
|
||||
from .sgd import SGD, SGD8bit, SGD32bit
|
||||
|
|
87
bitsandbytes/optim/lion.py
Normal file
87
bitsandbytes/optim/lion.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from bitsandbytes.optim.optimizer import Optimizer1State
|
||||
|
||||
|
||||
class Lion(Optimizer1State):
    """Lion optimizer registered with ``Optimizer1State`` under the name "lion".

    All constructor arguments are forwarded positionally and unchanged to
    ``Optimizer1State``; ``optim_bits`` defaults to 32.
    """

    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        # The fifth positional slot of Optimizer1State is pinned to 0.0
        # (presumably its eps argument -- confirm against the base class).
        fixed_fifth_arg = 0.0
        leading = ("lion", params, lr, betas, fixed_fifth_arg, weight_decay, optim_bits)
        trailing = (args, min_8bit_size, percentile_clipping, block_wise)
        super().__init__(*leading, *trailing)
|
||||
|
||||
|
||||
class Lion8bit(Optimizer1State):
    """Lion optimizer variant that always passes 8 as the bit-width argument.

    Takes the same arguments as the generic variant minus ``optim_bits`` and
    forwards them positionally to ``Optimizer1State`` under the name "lion".
    """

    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        # Fixed values for this variant: the fifth positional slot is 0.0
        # (presumably eps -- confirm against Optimizer1State) and the
        # bit-width slot is 8.
        fixed_fifth_arg = 0.0
        bit_width = 8
        leading = ("lion", params, lr, betas, fixed_fifth_arg, weight_decay, bit_width)
        trailing = (args, min_8bit_size, percentile_clipping, block_wise)
        super().__init__(*leading, *trailing)
|
||||
|
||||
|
||||
class Lion32bit(Optimizer1State):
    """Lion optimizer variant that always passes 32 as the bit-width argument.

    Takes the same arguments as the generic variant minus ``optim_bits`` and
    forwards them positionally to ``Optimizer1State`` under the name "lion".
    """

    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        # Fixed values for this variant: the fifth positional slot is 0.0
        # (presumably eps -- confirm against Optimizer1State) and the
        # bit-width slot is 32.
        fixed_fifth_arg = 0.0
        bit_width = 32
        leading = ("lion", params, lr, betas, fixed_fifth_arg, weight_decay, bit_width)
        trailing = (args, min_8bit_size, percentile_clipping, block_wise)
        super().__init__(*leading, *trailing)
|
|
@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
|
|||
return __int_as_float(old);
|
||||
}
|
||||
|
||||
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
//
// Returns +1 when val is greater than zero, -1 when val is less than zero,
// and 0 otherwise (i.e. whenever neither comparison holds).
template <typename T>
__device__ int sgn(T val)
{
  const T zero = T(0);
  const int above = zero < val;  // 1 if strictly positive, else 0
  const int below = val < zero;  // 1 if strictly negative, else 0
  return above - below;
}
|
||||
|
||||
template <int STOCHASTIC>
|
||||
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
|
||||
{
|
||||
|
@ -743,7 +751,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
|||
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
|
||||
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const int n)
|
||||
{
|
||||
|
||||
|
@ -790,6 +798,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
|||
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
|
||||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
break;
|
||||
case LION:
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
|
||||
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
|
||||
|
@ -821,7 +832,7 @@ template<typename T, int OPTIMIZER>
|
|||
__launch_bounds__(TH, 1)
|
||||
__global__ void kOptimizer32bit1State(T *g, T *p,
|
||||
float *state1, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
|
||||
{
|
||||
|
||||
|
@ -890,6 +901,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
|
|||
|
||||
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
|
||||
|
@ -1158,7 +1173,7 @@ __global__ void
|
|||
__launch_bounds__(NUM_THREADS, 2)
|
||||
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||
float *unorm,
|
||||
const float beta1,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
|
@ -1219,6 +1234,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
if(unorm != NULL)
|
||||
local_unorm += s1_vals[j]*s1_vals[j];
|
||||
break;
|
||||
case LION:
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
|
@ -1244,7 +1262,7 @@ template<typename T, int OPTIMIZER>
|
|||
__global__ void
|
||||
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||
const float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
|
@ -1307,8 +1325,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
{
|
||||
g_val = float(g_vals[j]);
|
||||
g_val *= gnorm_scale;
|
||||
if(weight_decay > 0.0f)
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
|
||||
if(weight_decay > 0.0f) {
|
||||
switch(OPTIMIZER) {
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
|
||||
|
||||
switch(OPTIMIZER)
|
||||
|
@ -1321,6 +1350,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
|
||||
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
|
||||
|
@ -1649,10 +1682,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
{
|
||||
g_val = float(g_vals[j]);
|
||||
g_val *= gnorm_scale;
|
||||
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
|
||||
{
|
||||
if(weight_decay > 0.0f)
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
|
||||
{
|
||||
if(weight_decay > 0.0f) {
|
||||
switch(OPTIMIZER) {
|
||||
case MOMENTUM:
|
||||
case ADAGRAD:
|
||||
case RMSPROP:
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
|
||||
|
||||
|
@ -1664,6 +1707,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
else
|
||||
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
|
||||
break;
|
||||
case LION:
|
||||
// here, g_vals[j] is reused to hold the signed step lr*sgn(beta1*state + (1-beta1)*grad) for the following parameter update, computed before the momentum is updated with beta2
|
||||
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
|
@ -1701,6 +1749,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
case MOMENTUM:
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
|
||||
break;
|
||||
case RMSPROP:
|
||||
g_val = g_vals[j];
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
|
||||
|
@ -2692,24 +2743,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
|
|||
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
|
||||
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
|
||||
float* state1, float *unorm, \
|
||||
const float beta1, const float eps, const float weight_decay, \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, const float gnorm_scale, const int n); \
|
||||
|
||||
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
|
||||
|
||||
#define MAKE_Optimizer32bit1State(oname, gtype) \
|
||||
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
|
||||
const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
|
||||
|
||||
MAKE_Optimizer32bit1State(MOMENTUM, half)
|
||||
MAKE_Optimizer32bit1State(MOMENTUM, float)
|
||||
MAKE_Optimizer32bit1State(RMSPROP, half)
|
||||
MAKE_Optimizer32bit1State(RMSPROP, float)
|
||||
MAKE_Optimizer32bit1State(LION, half)
|
||||
MAKE_Optimizer32bit1State(LION, float)
|
||||
MAKE_Optimizer32bit1State(ADAGRAD, half)
|
||||
MAKE_Optimizer32bit1State(ADAGRAD, float)
|
||||
|
||||
|
@ -2731,6 +2786,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
|
|||
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
|
||||
float *unorm, \
|
||||
const float beta1, \
|
||||
const float beta2, \
|
||||
const float eps, const int step, \
|
||||
float* __restrict__ const quantiles1, \
|
||||
float* max1, float* new_max1, \
|
||||
|
@ -2742,11 +2798,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
|
|||
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
|
||||
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
|
||||
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
|
||||
MAKE_PreconditionStatic8bit1State(LION, half)
|
||||
MAKE_PreconditionStatic8bit1State(LION, float)
|
||||
|
||||
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
|
||||
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
|
||||
const float *unorm, const float max_unorm, const float param_norm, \
|
||||
const float beta1, \
|
||||
const float beta2, \
|
||||
const float eps, const int step, const float lr, \
|
||||
float* __restrict__ const quantiles1, \
|
||||
float* max1, float* new_max1, \
|
||||
|
@ -2758,6 +2817,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
|
|||
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
|
||||
MAKE_optimizerStatic8bit1State(RMSPROP, half)
|
||||
MAKE_optimizerStatic8bit1State(RMSPROP, float)
|
||||
MAKE_optimizerStatic8bit1State(LION, half)
|
||||
MAKE_optimizerStatic8bit1State(LION, float)
|
||||
|
||||
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
|
||||
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
|
||||
|
@ -2849,5 +2910,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
|
|||
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
|
||||
|
|
|
@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
|
|||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void kOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void
|
||||
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||
float *unorm,
|
||||
const float beta1,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
|
@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
|
|||
__global__ void
|
||||
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||
const float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
|
|
38
csrc/ops.cu
38
csrc/ops.cu
|
@ -120,17 +120,28 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case LION:
|
||||
// in lion, the momentum update happens after the parameter update
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -164,12 +175,22 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case LION:
|
||||
// in lion, the momentum update happens after the parameter update
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -198,6 +219,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
case LION:
|
||||
num_blocks = n/BLOCKSIZE_1STATE;
|
||||
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
|
||||
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||
|
@ -707,6 +729,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
|
|||
MAKE_optimizer32bit(MOMENTUM, float)
|
||||
MAKE_optimizer32bit(RMSPROP, half)
|
||||
MAKE_optimizer32bit(RMSPROP, float)
|
||||
MAKE_optimizer32bit(LION, half)
|
||||
MAKE_optimizer32bit(LION, float)
|
||||
MAKE_optimizer32bit(ADAGRAD, half)
|
||||
MAKE_optimizer32bit(ADAGRAD, float)
|
||||
|
||||
|
@ -726,6 +750,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
|
|||
MAKE_optimizerStatic8bit(MOMENTUM, float)
|
||||
MAKE_optimizerStatic8bit(RMSPROP, half)
|
||||
MAKE_optimizerStatic8bit(RMSPROP, float)
|
||||
MAKE_optimizerStatic8bit(LION, half)
|
||||
MAKE_optimizerStatic8bit(LION, float)
|
||||
|
||||
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
|
||||
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
|
||||
|
@ -738,6 +764,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
|
|||
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
|
||||
|
||||
|
|
|
@ -70,6 +70,7 @@ typedef enum Optimizer_t
|
|||
RMSPROP = 2,
|
||||
LARS = 3,
|
||||
ADAGRAD = 4,
|
||||
LION = 5,
|
||||
} Optimizer_t;
|
||||
|
||||
typedef enum Transform_t
|
||||
|
|
|
@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
|
|||
MAKE_FUNC32(adam, ADAM, half, 16)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC32(lion, LION, float, 32)
|
||||
MAKE_FUNC32(lion, LION, half, 16)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
|
||||
|
||||
|
@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
|
|||
MAKE_FUNC8(momentum, MOMENTUM, half, 16)
|
||||
MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC8(lion, LION, float, 32)
|
||||
MAKE_FUNC8(lion, LION, half, 16)
|
||||
|
||||
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||
|
@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
|
|||
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_BLOCKWISE8(lion, LION, half, 16)
|
||||
MAKE_BLOCKWISE8(lion, LION, float, 32)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
||||
|
||||
|
@ -161,6 +167,8 @@ extern "C"
|
|||
MAKE_CFUNC32(momentum, half, 16)
|
||||
MAKE_CFUNC32(rmsprop, float, 32)
|
||||
MAKE_CFUNC32(rmsprop, half, 16)
|
||||
MAKE_CFUNC32(lion, float, 32)
|
||||
MAKE_CFUNC32(lion, half, 16)
|
||||
MAKE_CFUNC32(adagrad, float, 32)
|
||||
MAKE_CFUNC32(adagrad, half, 16)
|
||||
|
||||
|
@ -183,6 +191,8 @@ extern "C"
|
|||
MAKE_CFUNC8(momentum, half, 16)
|
||||
MAKE_CFUNC8(rmsprop, float, 32)
|
||||
MAKE_CFUNC8(rmsprop, half, 16)
|
||||
MAKE_CFUNC8(lion, float, 32)
|
||||
MAKE_CFUNC8(lion, half, 16)
|
||||
|
||||
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||
|
@ -196,6 +206,8 @@ extern "C"
|
|||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_CBLOCKWISE8(lion, LION, half, 16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, float, 32)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
||||
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
lion-pytorch
|
||||
pytest
|
||||
|
|
|
@ -7,6 +7,8 @@ from itertools import product
|
|||
from os.path import join
|
||||
|
||||
import pytest
|
||||
from lion_pytorch import Lion
|
||||
|
||||
import torch
|
||||
|
||||
import bitsandbytes as bnb
|
||||
|
@ -31,6 +33,7 @@ str2optimizers = {}
|
|||
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
|
||||
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
|
||||
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
|
||||
str2optimizers["momentum_pytorch"] = (
|
||||
None,
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
|
@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
|
|||
)
|
||||
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
|
||||
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
|
||||
str2optimizers["momentum"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
|
||||
|
@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
|
|||
torch.optim.Adam,
|
||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
|
||||
)
|
||||
str2optimizers["lion8bit"] = (
|
||||
Lion,
|
||||
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
|
||||
)
|
||||
str2optimizers["momentum8bit"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
|
||||
|
@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
|
|||
torch.optim.Adam,
|
||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
|
||||
)
|
||||
str2optimizers["lion8bit_blockwise"] = (
|
||||
Lion,
|
||||
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
|
||||
)
|
||||
str2optimizers["momentum8bit_blockwise"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
|
||||
|
@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
|
|||
|
||||
str2statenames = {}
|
||||
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["lion"] = [("exp_avg", "state1")]
|
||||
str2statenames["momentum"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lars"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
|
@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
|
|||
("exp_avg", "state1", "qmap1", "max1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "max2"),
|
||||
]
|
||||
str2statenames["lion8bit"] = [
|
||||
("exp_avg", "state1", "qmap1", "max1")
|
||||
]
|
||||
str2statenames["lamb8bit"] = [
|
||||
("exp_avg", "state1", "qmap1", "max1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "max2"),
|
||||
|
@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
|
|||
("exp_avg", "state1", "qmap1", "absmax1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "absmax2"),
|
||||
]
|
||||
str2statenames["lion8bit_blockwise"] = [
|
||||
("exp_avg", "state1", "qmap1", "absmax1")
|
||||
]
|
||||
str2statenames["momentum8bit"] = [
|
||||
("momentum_buffer", "state1", "qmap1", "max1")
|
||||
]
|
||||
|
@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
|
|||
dim1 = [1024]
|
||||
dim2 = [32, 1024, 4097, 1]
|
||||
gtype = [torch.float32, torch.float16]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
|
@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
|
|||
gtype = [torch.float32, torch.float16]
|
||||
optimizer_names = [
|
||||
"adam8bit",
|
||||
"lion8bit",
|
||||
"momentum8bit",
|
||||
"rmsprop8bit",
|
||||
"adam8bit_blockwise",
|
||||
"lion8bit_blockwise",
|
||||
"lars8bit",
|
||||
"momentum8bit_blockwise",
|
||||
"rmsprop8bit_blockwise",
|
||||
|
|
Loading…
Reference in New Issue
Block a user