Initial plumbing for skip_zeros.

Tim Dettmers 2021-10-20 18:37:44 -07:00
parent 8400b58cbb
commit bb34fd50a1
8 changed files with 86 additions and 42 deletions
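For orientation: this commit only threads a new skip_zeros flag from the Python optimizer classes through the ctypes wrappers down to the CUDA launcher and kernel signatures; the kernel bodies shown here do not yet branch on it. Going by the docstring added below ("Whether to skip zero-valued gradients or not"), the intended element-wise behavior would look roughly like the following pure-PyTorch sketch. This is an illustrative reference under that assumption, not the library's fused implementation; the function name and masking strategy are invented for the example.

import torch

def adam_step_reference(p, g, m, v, step, lr=1e-3, beta1=0.9,
                        beta2=0.999, eps=1e-8, skip_zeros=False):
    # With skip_zeros, entries whose gradient is exactly zero keep both
    # their parameter value and their optimizer state untouched.
    active = (g != 0) if skip_zeros else torch.ones_like(g, dtype=torch.bool)
    m[active] = beta1 * m[active] + (1 - beta1) * g[active]
    v[active] = beta2 * v[active] + (1 - beta2) * g[active] ** 2
    m_hat = m[active] / (1 - beta1 ** step)
    v_hat = v[active] / (1 - beta2 ** step)
    p.data[active] -= lr * m_hat / (v_hat.sqrt() + eps)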

View File

@@ -337,7 +337,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
                 beta1: float, eps: float, step: int, lr: float,
                 state2: Tensor=None, beta2: float=0.0,
                 weight_decay: float=0.0, gnorm_scale: float=1.0,
-                unorm_vec: Tensor=None, max_unorm: float=0.0) -> None:
+                unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None:
     '''
     Performs an inplace optimizer update with one or two optimizer states.
@@ -369,6 +369,12 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
         Optimizer beta2.
     gnorm_scale : float
         The factor to rescale the gradient to the max clip value.
+    unorm_vec : torch.Tensor
+        The tensor for the update norm.
+    max_unorm : float
+        The maximum update norm relative to the weight norm.
+    skip_zeros : bool
+        Whether to skip zero-valued gradients or not (default: False).
     '''
     param_norm = 0.0
@@ -381,11 +387,11 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
     if g.dtype == torch.float32 and state1.dtype == torch.float32:
         str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
                 ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
-                ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     elif g.dtype == torch.float16 and state1.dtype == torch.float32:
         str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
                 ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay),
-                ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     else:
         raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
@@ -439,6 +445,10 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
         Max value for the next Adam update of the second state.
     gnorm_scale : float
         The factor to rescale the gradient to the max clip value.
+    unorm_vec : torch.Tensor
+        The tensor for the update norm.
+    max_unorm : float
+        The maximum update norm relative to the weight norm.
     '''
     param_norm = 0.0
@@ -468,19 +478,22 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten
 def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor,
                 beta1: float, beta2: float, eps: float,
                 step: int, lr: float, qmap1: Tensor, qmap2: Tensor,
-                absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0) -> None:
+                absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0,
+                skip_zeros=False) -> None:
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
                 ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
                 ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
-                get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+                ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
                 ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
                 ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
-                get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()))
+                get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+                ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
     else:
         raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
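The new flag crosses the Python/C boundary as ct.c_bool(skip_zeros), placed immediately before the element count, so the positional argument order must mirror the updated C entry points (see the pythonInterface changes further down) exactly. As a hedged sketch, an explicit ctypes declaration for the fp32 Adam entry point would look like the following; the library handle is illustrative, and the wrappers above actually pass pre-wrapped values positionally rather than declaring argtypes.

import ctypes as ct

lib = ct.CDLL('libbitsandbytes.so')  # assumed library name
lib.cadam32bit_g32.restype = None
lib.cadam32bit_g32.argtypes = (
    [ct.c_void_p] * 5                        # g, p, state1, state2, unorm
    + [ct.c_float] * 6                       # max_unorm, param_norm, beta1, beta2, eps, weight_decay
    + [ct.c_int32, ct.c_float, ct.c_float]   # step, lr, gnorm_scale
    + [ct.c_bool, ct.c_int32])               # skip_zeros, n -- the flag sits second to last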

View File

@@ -220,6 +220,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             config['percentile_clipping'] = self.args.percentile_clipping
             config['block_wise'] = self.args.block_wise
             config['max_unorm'] = self.args.max_unorm
+            config['skip_zeros'] = self.args.skip_zeros
         if (gindex, pindex) in self.mng.index2config:
             config.update(self.mng.index2config[(gindex, pindex)])
@@ -234,7 +235,8 @@ class Optimizer8bit(torch.optim.Optimizer):
 class Optimizer2State(Optimizer8bit):
     def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.0, optim_bits=32, args=None,
-                min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+                min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
+                skip_zeros=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -259,6 +261,7 @@ class Optimizer2State(Optimizer8bit):
             args['percentile_clipping'] = percentile_clipping
             args['block_wise'] = block_wise
             args['max_unorm'] = max_unorm
+            args['skip_zeros'] = skip_zeros
             self.args = MockArgs(args)
         else:
@@ -355,7 +358,8 @@ class Optimizer2State(Optimizer8bit):
 class Optimizer1State(Optimizer8bit):
     def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
                 weight_decay=0.0, optim_bits=32, args=None,
-                min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+                min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
+                skip_zeros=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -377,6 +381,7 @@ class Optimizer1State(Optimizer8bit):
             args['percentile_clipping'] = percentile_clipping
             args['block_wise'] = block_wise
             args['max_unorm'] = max_unorm
+            args['skip_zeros'] = skip_zeros
             self.args = MockArgs(args)
         else:
@@ -444,7 +449,8 @@ class Optimizer1State(Optimizer8bit):
         if state['state1'].dtype == torch.float:
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     None, 0.0, config['weight_decay'], gnorm_scale,
-                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
+                    skip_zeros=False)
         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
@@ -457,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'],
                     state['qmap1'], None, state['absmax1'], None,
-                    config['weight_decay'], gnorm_scale=gnorm_scale)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
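Note that the Optimizer1State call sites shown above still pass skip_zeros=False unconditionally: the constructor argument and the config['skip_zeros'] entry are recorded but not yet forwarded, consistent with "initial plumbing". Per-parameter configuration goes through the override mechanism, as the test at the bottom of this commit does. A usage sketch under that pattern (the parameter itself is illustrative):

import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.randn(1024, 1024, device='cuda'))
mgr = bnb.optim.GlobalOptimManager.get_instance()
mgr.initialize()
mgr.override_config(p, 'skip_zeros', True)   # per-parameter override
mgr.register_parameters([p])
adam = bnb.optim.Adam([p], lr=1e-3)          # reads the override during step()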

View File

@@ -654,7 +654,7 @@ __launch_bounds__(TH, 1)
 __global__ void kOptimizer32bit2State(T* g, T* p,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
 {
   const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
@@ -809,7 +809,7 @@ __launch_bounds__(TH, 1)
 __global__ void kOptimizer32bit1State(T *g, T *p,
                 float *state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
 {
   const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
@@ -1383,7 +1383,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
                 float* absmax1, float* absmax2,
                 float weight_decay,
-                const float gnorm_scale, const int n)
+                const float gnorm_scale, const bool skip_zeros, const int n)
 {
   //const int n_full = n + (n%BLOCK_SIZE);
@@ -1555,7 +1555,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                 float* __restrict__ const quantiles1,
                 float* absmax1,
                 float weight_decay,
-                const float gnorm_scale, const int n)
+                const float gnorm_scale, const bool skip_zeros, const int n)
 {
   //const int n_full = n + (n%BLOCK_SIZE);
@@ -1723,7 +1723,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
 #define MAKE_Optimizer32bit1State(oname, gtype) \
 template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
-                const float beta1, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n); \
+                const float beta1, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
 MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
@@ -1740,9 +1740,9 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, half)
 MAKE_PreconditionOptimizer32bit2State(ADAM, float)
 template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n);
+                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n);
+                const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 #define MAKE_PreconditionStatic8bit1State(oname, gtype) \
 template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
@@ -1825,7 +1825,7 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
                 float* absmax1, float* absmax2, \
                 float weight_decay, \
-                const float gnorm_scale, const int n); \
+                const float gnorm_scale, const bool skip_zeros, const int n); \
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
@@ -1838,7 +1838,7 @@ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block
                 float* __restrict__ const quantiles1, \
                 float* absmax1, \
                 float weight_decay, \
-                const float gnorm_scale, const int n); \
+                const float gnorm_scale, const bool skip_zeros, const int n); \
 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)

View File

@@ -27,7 +27,7 @@ template<typename T, int OPTIMIZER>
 __global__ void kOptimizer32bit2State(T* g, T* p,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
@@ -39,7 +39,7 @@ template<typename T, int OPTIMIZER>
 __global__ void kOptimizer32bit1State(T* g, T* p,
                 float* state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 template<typename T, int OPTIMIZER>
 __global__ void
@@ -90,7 +90,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
                 T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
                 const float beta1, const float beta2, const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
-                float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const int n);
+                float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
                 T* p, T* __restrict__ const g, unsigned char* state1,
@@ -99,7 +99,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
                 float* __restrict__ const quantiles1,
                 float* absmax1,
                 float weight_decay,
-                const float gnorm_scale, const int n);
+                const float gnorm_scale, const bool skip_zeros, const int n);
 template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);

View File

@@ -181,7 +181,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
-                const int step, const float lr, const float gnorm_scale, const int n)
+                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
 {
   int blocks = n/4096;
   blocks = n % 4096 == 0 ? blocks : blocks + 1;
@@ -194,7 +194,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
         kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
@@ -206,7 +206,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
   }
@@ -259,7 +259,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
 {
   int blocks = 0;
@@ -269,7 +269,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
       blocks = n/BLOCKSIZE_2STATE;
      blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
       kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
-                quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n);
+                quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
@@ -277,7 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
       blocks = n/BLOCKSIZE_1STATE;
       blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
       kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
-                quantiles1, absmax1, weight_decay, gnorm_scale, n);
+                quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
   }
@@ -313,7 +313,7 @@ template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *a
 template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
                 float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, const float gnorm_scale, const int n);
+                const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 MAKE_optimizer32bit(ADAM, half)
 MAKE_optimizer32bit(ADAM, float)
@@ -342,7 +342,7 @@ MAKE_optimizerStatic8bit(RMSPROP, float)
 #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
 template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); \
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
 MAKE_optimizerStatic8bitBlockwise(half, ADAM);
 MAKE_optimizerStatic8bitBlockwise(float, ADAM);
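All of these launchers size their grids the same way: integer division of n by the block's element count, plus one extra block when there is a remainder. The equivalent ceiling division, spelled out in Python just to make the arithmetic explicit:

def num_blocks(n: int, block_elems: int = 4096) -> int:
    # same as: blocks = n // block_elems; if n % block_elems: blocks += 1
    return (n + block_elems - 1) // block_elems

assert num_blocks(4096) == 1 and num_blocks(4097) == 2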

View File

@@ -49,7 +49,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2, float eps, float weight_decay,
-                int step, float lr, const float gnorm_scale, int n);
+                int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
 template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
                 float *unorm, float max_unorm, float param_norm,
@@ -62,7 +62,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigne
 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n);
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
+                bool skip_zeros, int n);
 template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);

View File

@@ -20,8 +20,8 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
 void fname##32bit_g##gbits(gtype *g, gtype *p, \
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, float gnorm_scale, const int n) \
-{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+                const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
+{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
 MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -53,8 +53,8 @@ MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
 void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)\
-{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); }\
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
+{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
 MAKE_BLOCKWISE8(adam, ADAM, half, 16)
 MAKE_BLOCKWISE8(adam, ADAM, float, 32)
@@ -93,8 +93,8 @@ extern "C"
 void c##name##32bit_g##gbits(gtype *g, gtype *p, \
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay, \
-                const int step, const float lr, const float gnorm_scale, const int n) \
-{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
+{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
 MAKE_CFUNC32(adam, float, 32)
 MAKE_CFUNC32(adam, half, 16)
@@ -110,7 +110,7 @@ extern "C"
                 float eps, int step, float lr, \
                 float* quantiles1, float* quantiles2, \
                 float* max1, float* max2, float* new_max1, float* new_max2, \
-                float weight_decay, float gnorm_scale, int n) \
+                float weight_decay, float gnorm_scale, bool skip_zeros, int n) \
 { \
     name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
                 quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
@@ -126,8 +126,8 @@ extern "C"
 #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
 void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
-                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) \
-{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); } \
+                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
+{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
 MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
 MAKE_CBLOCKWISE8(adam, ADAM, float, 32)

View File

@@ -141,6 +141,7 @@ def test_global_config(dim1, dim2, gtype):
     eps = 1e-8
     bnb.optim.GlobalOptimManager.get_instance().initialize()
+    bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
     bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
     bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@@ -155,6 +156,8 @@ def test_global_config(dim1, dim2, gtype):
     else:
         atol, rtol = 1e-4, 1e-3
+    original_p2 = p2[mask].clone()
     for i in range(50):
         g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
         g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -163,11 +166,32 @@ def test_global_config(dim1, dim2, gtype):
         p2.grad = g2
         p3.grad = g3
+        if i > 30 and i % 10 == 0:
+            g1.data[mask] = 0.0
+            g2.data[mask] = 0.0
+            p1.grad = g1
+            p2.grad = g2
+            original_p1 = p1[mask].clone()
+            original_p2 = p2[mask].clone()
+            og_s1 = adam2.state[p2]['state1'][mask].clone()
+            og_s2 = adam2.state[p2]['state2'][mask].clone()
+            og_s11 = adam2.state[p1]['state1'][mask].clone()
+            og_s21 = adam2.state[p1]['state2'][mask].clone()
         adam2.step()
         assert adam2.state[p3]['state1'].dtype == torch.uint8
         assert adam2.state[p3]['state2'].dtype == torch.uint8
+        if i > 30 and i % 10 == 0:
+            torch.testing.assert_allclose(original_p2, p2[mask])
+            torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
+            torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
+            assert ((p1[mask] - original_p1) == 0.0).sum() < p1.numel()
+            assert ((adam2.state[p1]['state1'][mask] - og_s11) == 0.0).sum() == 0.0
+            assert ((adam2.state[p1]['state2'][mask] - og_s21) == 0.0).sum() == 0.0
 dim1 = [1024]
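To summarize the new test logic: on every tenth step after step 30, the gradients of p1 and p2 are zeroed at mask before adam2.step(). p2 carries the skip_zeros override, so its masked parameters and both optimizer states must come back unchanged; p1 uses the default skip_zeros=False and must keep moving, since Adam's momentum and variance states decay and update the weights even where the gradient is zero. The core invariant, condensed and reusing the test's names:

g2 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g2[mask] = 0.0                       # zero the gradient at the masked entries
p2.grad = g2
frozen = p2.data[mask].clone()
adam2.step()
torch.testing.assert_allclose(p2.data[mask], frozen)  # untouched under skip_zeros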