Initial plumbing for skip_zeros.
parent 8400b58cbb
commit bb34fd50a1
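This commit threads a new skip_zeros flag from the Python optimizer classes through the ctypes bindings and C wrappers down to the CUDA kernel signatures. Once switched on, the flag is meant to leave parameters and optimizer state untouched wherever the gradient is exactly zero, which matters when sparse gradients would otherwise decay state for rows that were never touched. A usage sketch, under the assumption that the Adam subclass forwards the new keyword to Optimizer2State (this commit only adds it to the base classes):

    import torch
    import bitsandbytes as bnb

    layer = torch.nn.Linear(512, 512).cuda()
    # skip_zeros=True is the flag introduced here; Adam forwarding it to
    # Optimizer2State is an assumption, not something this commit does.
    opt = bnb.optim.Adam(layer.parameters(), lr=1e-3, skip_zeros=True)
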
bitsandbytes/functional.py

@@ -337,7 +337,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
                 beta1: float, eps: float, step: int, lr: float,
                 state2: Tensor=None, beta2: float=0.0,
                 weight_decay: float=0.0, gnorm_scale: float=1.0,
-                unorm_vec: Tensor=None, max_unorm: float=0.0) -> None:
+                unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None:
     '''
     Performs an inplace optimizer update with one or two optimizer states.
 
@@ -369,6 +369,12 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
         Optimizer beta2.
     gnorm_scale : float
         The factor to rescale the gradient to the max clip value.
+    unorm_vec : torch.Tensor
+        The tensor for the update norm.
+    max_unorm : float
+        The maximum update norm relative to the weight norm.
+    skip_zeros : bool
+        Whether to skip zero-valued gradients or not (default: False).
     '''
 
     param_norm = 0.0
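The docstring pins down the contract: with skip_zeros=True, elements whose gradient is exactly zero should see no update at all. The kernel bodies are untouched in this commit (only their signatures change), so the following is a plain-Python reference of the presumable masking semantics, not the shipped implementation:

    import torch

    def adam_update_reference(p, g, m, v, lr, beta1, beta2, eps, step, skip_zeros):
        # Reference semantics: masked-out elements keep their old parameter
        # and state values; everything else takes a normal Adam step.
        active = (g != 0) if skip_zeros else torch.ones_like(g, dtype=torch.bool)
        m_new = beta1 * m + (1 - beta1) * g
        v_new = beta2 * v + (1 - beta2) * g * g
        m_hat = m_new / (1 - beta1 ** step)
        v_hat = v_new / (1 - beta2 ** step)
        p_new = p - lr * m_hat / (v_hat.sqrt() + eps)
        return (torch.where(active, p_new, p),
                torch.where(active, m_new, m),
                torch.where(active, v_new, v))
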
@@ -381,11 +387,11 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
     if g.dtype == torch.float32 and state1.dtype == torch.float32:
         str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
                 ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
-                ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     elif g.dtype == torch.float16 and state1.dtype == torch.float32:
         str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
                 ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
-                ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     else:
         raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
 
@@ -439,6 +445,10 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
         Max value for the next Adam update of the second state.
     gnorm_scale : float
         The factor to rescale the gradient to the max clip value.
+    unorm_vec : torch.Tensor
+        The tensor for the update norm.
+    max_unorm : float
+        The maximum update norm relative to the weight norm.
     '''
 
     param_norm = 0.0
@@ -468,19 +478,22 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
 def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
                 beta1: float, beta2: float, eps: float,
                 step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
-                absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0) -> None:
+                absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0,
+                skip_zeros=False) -> None:
 
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
                 ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
                 ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
-                get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+                ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
                 ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
                 ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
-                get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+                ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     else:
         raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
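On the binding side, the flag crosses the FFI boundary as ct.c_bool, slotted immediately before the element count. These calls pass arguments positionally without declaring argtypes, so the Python call order must mirror the C signature slot for slot; the explicit wrapper is what guarantees a one-byte C bool rather than the default integer promotion:

    import ctypes as ct

    # ct.c_bool coerces any truthy/falsy Python value to a one-byte C bool.
    assert ct.c_bool(True).value == True
    assert ct.c_bool(0).value == False
    # With no argtypes declared, a bare Python bool would be promoted to a
    # C int, hence the explicit ct.c_bool(skip_zeros) at every call site.
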
bitsandbytes/optim/optimizer.py

@@ -220,6 +220,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             config['percentile_clipping'] = self.args.percentile_clipping
             config['block_wise'] = self.args.block_wise
             config['max_unorm'] = self.args.max_unorm
+            config['skip_zeros'] = self.args.skip_zeros
 
         if (gindex, pindex) in self.mng.index2config:
             config.update(self.mng.index2config[(gindex, pindex)])
@@ -234,7 +235,8 @@ class Optimizer8bit(torch.optim.Optimizer):
 class Optimizer2State(Optimizer8bit):
     def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0.0, optim_bits=32, args=None,
-            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
+            skip_zeros=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -259,6 +261,7 @@ class Optimizer2State(Optimizer8bit):
             args['percentile_clipping'] = percentile_clipping
             args['block_wise'] = block_wise
             args['max_unorm'] = max_unorm
+            args['skip_zeros'] = skip_zeros
 
             self.args = MockArgs(args)
         else:
@@ -355,7 +358,8 @@ class Optimizer2State(Optimizer8bit):
 class Optimizer1State(Optimizer8bit):
     def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
             weight_decay=0.0, optim_bits=32, args=None,
-            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
+            skip_zeros=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -377,6 +381,7 @@ class Optimizer1State(Optimizer8bit):
             args['percentile_clipping'] = percentile_clipping
             args['block_wise'] = block_wise
             args['max_unorm'] = max_unorm
+            args['skip_zeros'] = skip_zeros
 
             self.args = MockArgs(args)
         else:
@@ -444,7 +449,8 @@ class Optimizer1State(Optimizer8bit):
         if state['state1'].dtype == torch.float:
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     None, 0.0, config['weight_decay'], gnorm_scale,
-                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
+                    skip_zeros=False)
 
         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
@@ -457,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'],
                     state['qmap1'], None, state['absmax1'], None,
-                    config['weight_decay'], gnorm_scale=gnorm_scale)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
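Both step() call sites above hard-code skip_zeros=False even though config['skip_zeros'] is now populated, which matches the "initial plumbing" framing: the path exists but is not yet switched on. A speculative sketch of the eventual wiring (not part of this commit):

    # Hypothetical follow-up: pass the per-parameter config instead of False.
    F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'],
            config['betas'][0], config['eps'], step, config['lr'],
            None, 0.0, config['weight_decay'], gnorm_scale,
            state['unorm_vec'] if config['max_unorm'] > 0.0 else None,
            max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
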
csrc/kernels.cu

@@ -654,7 +654,7 @@ __launch_bounds__(TH, 1)
 __global__ void kOptimizer32bit2State(T* g, T* p,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
 {
 
   const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
@@ -809,7 +809,7 @@ __launch_bounds__(TH, 1)
 __global__ void kOptimizer32bit1State(T *g, T *p,
                 float *state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
 {
 
   const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
@@ -1383,7 +1383,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                 float* absmax1, float* absmax2,
                 float weight_decay,
-                const float gnorm_scale, const int n)
+                const float gnorm_scale, const bool skip_zeros, const int n)
 {
 
     //const int n_full = n + (n%BLOCK_SIZE);
@@ -1555,7 +1555,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                 float* __restrict__ const quantiles1,
                 float* absmax1,
                 float weight_decay,
-                const float gnorm_scale, const int n)
+                const float gnorm_scale, const bool skip_zeros, const int n)
 {
 
     //const int n_full = n + (n%BLOCK_SIZE);
@@ -1723,7 +1723,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
 
 #define MAKE_Optimizer32bit1State(oname, gtype) \
 template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
-const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); \
+const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
 
 MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
@@ -1740,9 +1740,9 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, half)
 MAKE_PreconditionOptimizer32bit2State(ADAM, float)
 
 template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
-const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n);
+const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
-const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n);
+const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
 #define MAKE_PreconditionStatic8bit1State(oname, gtype) \
 template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
@@ -1825,7 +1825,7 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
                 float* absmax1, float* absmax2, \
                 float weight_decay, \
-                const float gnorm_scale, const int n); \
+                const float gnorm_scale, const bool skip_zeros, const int n); \
 
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
@@ -1838,7 +1838,7 @@ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block
                 float* __restrict__ const quantiles1, \
                 float* absmax1, \
                 float weight_decay, \
-                const float gnorm_scale, const int n); \
+                const float gnorm_scale, const bool skip_zeros, const int n); \
 
 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
csrc/kernels.cuh

@@ -27,7 +27,7 @@ template<typename T, int OPTIMIZER>
 __global__ void kOptimizer32bit2State(T* g, T* p,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
@@ -39,7 +39,7 @@ template<typename T, int OPTIMIZER>
 __global__ void kOptimizer32bit1State(T* g, T* p,
                 float* state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
 template<typename T, int OPTIMIZER>
 __global__ void
@@ -90,7 +90,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
                 T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
                 const float beta1, const float beta2, const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
-                float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const int n);
+                float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
 
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
                 T* p, T* __restrict__ const g, unsigned char* state1,
@@ -99,7 +99,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
                 float* __restrict__ const quantiles1,
                 float* absmax1,
                 float weight_decay,
-                const float gnorm_scale, const int n);
+                const float gnorm_scale, const bool skip_zeros, const int n);
 
 
 template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
csrc/ops.cu
@@ -181,7 +181,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
 {
   int blocks = n/4096;
   blocks = n % 4096 == 0 ? blocks : blocks + 1;
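The launch configuration above is a ceiling division: n elements, 4096 per block, rounded up so a partial tail still gets a block. The same arithmetic in Python, for reference:

    def ceil_div(n: int, block: int) -> int:
        # Same as: blocks = n/block; blocks = n % block == 0 ? blocks : blocks + 1
        return (n + block - 1) // block

    assert ceil_div(4096, 4096) == 1
    assert ceil_div(4097, 4096) == 2
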
@@ -194,7 +194,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
         kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
@@ -206,7 +206,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
 
-      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
   }
@@ -259,7 +259,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
 
 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
 {
 
   int blocks = 0;
@@ -269,7 +269,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
       blocks = n/BLOCKSIZE_2STATE;
       blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
       kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
-                quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n);
+                quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
@@ -277,7 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
       blocks = n/BLOCKSIZE_1STATE;
      blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
       kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
-                quantiles1, absmax1, weight_decay, gnorm_scale, n);
+                quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
   }
@@ -313,7 +313,7 @@ template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *a
 template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
                 float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
 MAKE_optimizer32bit(ADAM, half)
 MAKE_optimizer32bit(ADAM, float)
@@ -342,7 +342,7 @@ MAKE_optimizerStatic8bit(RMSPROP, float)
 #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
 template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); \
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
 
 MAKE_optimizerStatic8bitBlockwise(half, ADAM);
 MAKE_optimizerStatic8bitBlockwise(float, ADAM);
csrc/ops.cuh

@@ -49,7 +49,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2, float eps, float weight_decay,
-                int step, float lr, const float gnorm_scale, int n);
+                int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
 
 template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
                 float *unorm, float max_unorm, float param_norm,
@@ -62,7 +62,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigne
 
 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n);
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
+                bool skip_zeros, int n);
 
 template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
csrc/pythonInterface.c

@@ -20,8 +20,8 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
 void fname##32bit_g##gbits(gtype *g, gtype *p, \
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, float gnorm_scale, const int n) \
-{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+                const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
+{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
 
 MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -53,8 +53,8 @@ MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
 void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)\
-{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); }\
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
+{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
 
 MAKE_BLOCKWISE8(adam, ADAM, half, 16)
 MAKE_BLOCKWISE8(adam, ADAM, float, 32)
@@ -93,8 +93,8 @@ extern "C"
 void c##name##32bit_g##gbits(gtype *g, gtype *p, \
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, const float gnorm_scale, const int n) \
-{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
+{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
 
 MAKE_CFUNC32(adam, float, 32)
 MAKE_CFUNC32(adam, half, 16)
@@ -110,7 +110,7 @@ extern "C"
                 float eps, int step, float lr, \
                 float* quantiles1, float* quantiles2, \
                 float* max1, float* max2, float* new_max1, float* new_max2, \
-                float weight_decay, float gnorm_scale, int n) \
+                float weight_decay, float gnorm_scale, bool skip_zeros, int n) \
 { \
     name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
                 quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
@@ -126,8 +126,8 @@ extern "C"
 #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
 void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) \
-{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); } \
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
+{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
 
 MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
 MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
tests/test_optim.py

@@ -141,6 +141,7 @@ def test_global_config(dim1, dim2, gtype):
     eps = 1e-8
 
     bnb.optim.GlobalOptimManager.get_instance().initialize()
+    bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
     bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
 
     bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@@ -155,6 +156,8 @@ def test_global_config(dim1, dim2, gtype):
     else:
         atol, rtol = 1e-4, 1e-3
 
+    original_p2 = p2[mask].clone()
+
     for i in range(50):
         g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
         g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -163,11 +166,32 @@ def test_global_config(dim1, dim2, gtype):
         p2.grad = g2
         p3.grad = g3
 
+        if i > 30 and i % 10 == 0:
+            g1.data[mask] = 0.0
+            g2.data[mask] = 0.0
+            p1.grad = g1
+            p2.grad = g2
+            original_p1 = p1[mask].clone()
+            original_p2 = p2[mask].clone()
+            og_s1 = adam2.state[p2]['state1'][mask].clone()
+            og_s2 = adam2.state[p2]['state2'][mask].clone()
+            og_s11 = adam2.state[p1]['state1'][mask].clone()
+            og_s21 = adam2.state[p1]['state2'][mask].clone()
+
         adam2.step()
 
         assert adam2.state[p3]['state1'].dtype == torch.uint8
         assert adam2.state[p3]['state2'].dtype == torch.uint8
 
+        if i > 30 and i % 10 == 0:
+            torch.testing.assert_allclose(original_p2, p2[mask])
+            torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
+            torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
+            assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel()
+            assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0
+            assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0
+
 
 dim1 = [1024]
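The asymmetry in the final assertions is the point of the test: p2 (skip_zeros overridden to True) must stay bit-identical under the zeroed mask, while p1 still moves there, because a plain Adam step is nonzero even at g = 0 once momentum has accumulated. A scalar check of that fact:

    # With g = 0, Adam still updates: the first moment decays but does not
    # vanish, so lr * m_hat / (sqrt(v_hat) + eps) stays nonzero.
    m, v, lr, beta1, beta2, eps, step = 0.05, 0.01, 1e-3, 0.9, 0.999, 1e-8, 40
    g = 0.0
    m = beta1 * m + (1 - beta1) * g          # 0.045, not zero
    v = beta2 * v + (1 - beta2) * g * g
    update = lr * (m / (1 - beta1 ** step)) / ((v / (1 - beta2 ** step)) ** 0.5 + eps)
    assert update != 0.0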