forked from mrq/bitsandbytes-rocm

Added skip_zeros; tests are passing.

parent bb34fd50a1
commit a6eae2e7f2
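The commit threads a skip_zeros option from the Python optimizer configuration (config['skip_zeros']) down into the CUDA update kernels, which now wrap each per-element update in if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)); the second skip_zeros test is redundant, so the guard simply means "update only when the gradient is non-zero". The sketch below is a minimal standalone illustration of that guard on an Adam-style 32-bit update; it is not the library's kernel, the kernel name, the m/v state names, and the step_size parameter are illustrative assumptions, and bias correction is omitted.

// Minimal sketch (not the bitsandbytes kernel) of the skip_zeros guard:
// with skip_zeros set, an element whose gradient is exactly zero keeps its
// optimizer state and parameter unchanged.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void adam_update_sketch(float* p, const float* g, float* m, float* v,
                                   float beta1, float beta2, float eps,
                                   float step_size,   // carries the (negative) learning-rate factor
                                   bool skip_zeros, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    // Same condition as if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)) in the diff.
    if (!skip_zeros || g[i] != 0.0f)
    {
        m[i] = m[i] * beta1 + (1.0f - beta1) * g[i];            // first moment
        v[i] = v[i] * beta2 + (1.0f - beta2) * g[i] * g[i];     // second moment
        p[i] = p[i] + step_size * (m[i] / (sqrtf(v[i]) + eps)); // parameter update
    }
}

int main()
{
    const int n = 4;
    float *p, *g, *m, *v;
    cudaMallocManaged(&p, n * sizeof(float));
    cudaMallocManaged(&g, n * sizeof(float));
    cudaMallocManaged(&m, n * sizeof(float));
    cudaMallocManaged(&v, n * sizeof(float));
    for (int i = 0; i < n; i++) { p[i] = 1.0f; g[i] = (i % 2) ? 0.5f : 0.0f; m[i] = v[i] = 0.0f; }

    // With skip_zeros=true, elements whose gradient is zero stay at exactly 1.0f.
    adam_update_sketch<<<1, 32>>>(p, g, m, v, 0.9f, 0.999f, 1e-8f, -1e-3f, true, n);
    cudaDeviceSynchronize();
    for (int i = 0; i < n; i++) printf("p[%d] = %f (g = %f)\n", i, p[i], g[i]);

    cudaFree(p); cudaFree(g); cudaFree(m); cudaFree(v);
    return 0;
}

Compiled with nvcc, the output shows the zero-gradient elements untouched (p[i] == 1.0) while the others take a normal Adam step, which is the behavior the kernels below gain when skip_zeros is enabled.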
Makefile (46 changed lines)
@@ -15,29 +15,31 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
 LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
 
 # NVIDIA NVCC compilation flags
-COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-
-# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
-CC_CUDA92 := -gencode arch=compute_30,code=sm_30
-
-# Later versions of CUDA support the new architectures
-CC_CUDA10x := -gencode arch=compute_30,code=sm_30
-CC_CUDA10x += -gencode arch=compute_75,code=sm_75
-
-CC_CUDA110 := -gencode arch=compute_75,code=sm_75
-CC_CUDA110 += -gencode arch=compute_80,code=sm_80
-
-CC_CUDA11x := -gencode arch=compute_75,code=sm_75
-CC_CUDA11x += -gencode arch=compute_80,code=sm_80
-CC_CUDA11x += -gencode arch=compute_86,code=sm_86
+#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#
+## CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
+#CC_CUDA92 := -gencode arch=compute_30,code=sm_30
+#
+## Later versions of CUDA support the new architectures
+#CC_CUDA10x := -gencode arch=compute_30,code=sm_30
+#CC_CUDA10x += -gencode arch=compute_75,code=sm_75
+#
+#CC_CUDA110 := -gencode arch=compute_75,code=sm_75
+#CC_CUDA110 += -gencode arch=compute_80,code=sm_80
+#
+#CC_CUDA11x := -gencode arch=compute_75,code=sm_75
+#CC_CUDA11x += -gencode arch=compute_80,code=sm_80
+#CC_CUDA11x += -gencode arch=compute_86,code=sm_86
+
+COMPUTE_CAPABILITY := -gencode arch=compute_70,code=sm_70 # Volta
 
 
 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
@@ -336,7 +336,7 @@ class Optimizer2State(Optimizer8bit):
         if state['state1'].dtype == torch.float:
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
-                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
 
         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
@@ -352,7 +352,7 @@ class Optimizer2State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'],
                     state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
-                    config['weight_decay'], gnorm_scale=gnorm_scale)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
 
 
 class Optimizer1State(Optimizer8bit):
@@ -450,7 +450,7 @@ class Optimizer1State(Optimizer8bit):
            F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                    None, 0.0, config['weight_decay'], gnorm_scale,
                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
-                   skip_zeros=False)
+                   skip_zeros=config['skip_zeros'])
 
        elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
            F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
@@ -463,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
            F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                    config['eps'], step, config['lr'],
                    state['qmap1'], None, state['absmax1'], None,
-                   config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
+                   config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])

csrc/kernels.cu (116 changed lines)
@@ -715,9 +715,12 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
       switch(OPTIMIZER)
       {
           case ADAM:
-              s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
-              s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
-              p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+              if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+              {
+                s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
+                s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
+                p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+              }
               break;
       }
   }
@@ -865,21 +868,24 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
   # pragma unroll 4
   for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
   {
-      switch(OPTIMIZER)
-      {
-          case MOMENTUM:
-              if(step == 1)
-                s1_vals[j] = (float)g_vals[j];
-              else
-                s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
-
-              p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
-              break;
-          case RMSPROP:
-              s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
-              p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
-              break;
-      }
+      if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+      {
+          switch(OPTIMIZER)
+          {
+              case MOMENTUM:
+                  if(step == 1)
+                    s1_vals[j] = (float)g_vals[j];
+                  else
+                    s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
+
+                  p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
+                  break;
+              case RMSPROP:
+                  s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
+                  p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
+                  break;
+          }
+      }
   }
 
   __syncthreads();
@@ -1469,11 +1475,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
   {
       g_val = float(g_vals[j]);
       g_val *= gnorm_scale;
-      s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
-      s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
-
-      s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
-      s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+      if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+      {
+          s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+          s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
+
+          s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
+          s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+      }
 
       new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
       new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));
@@ -1509,9 +1518,12 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
   # pragma unroll N_PER_TH
   for(unsigned int j = 0; j < N_PER_TH; j++)
   {
-      g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
-      if(weight_decay > 0.0f)
-          g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
+      if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+      {
+          g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
+          if(weight_decay > 0.0f)
+              g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
+      }
   }
 
   // store: 0.85/1.44 -> 2.48/1.57
@@ -1623,23 +1635,26 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
   {
       g_val = float(g_vals[j]);
       g_val *= gnorm_scale;
-      if(weight_decay > 0.0f)
-          g_val += ((float)p_vals[j])*weight_decay;
-
-      s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
-
-      switch(OPTIMIZER)
-      {
-          case MOMENTUM:
-              if(step == 1)
-                  s1_vals[j] = g_val;
-              else
-                  s1_vals[j] = (s1_vals[j]*beta1) + g_val;
-              break;
-          case RMSPROP:
-              s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
-              break;
-      }
+      if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+      {
+          if(weight_decay > 0.0f)
+              g_val += ((float)p_vals[j])*weight_decay;
+
+          s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+
+          switch(OPTIMIZER)
+          {
+              case MOMENTUM:
+                  if(step == 1)
+                      s1_vals[j] = g_val;
+                  else
+                      s1_vals[j] = (s1_vals[j]*beta1) + g_val;
+                  break;
+              case RMSPROP:
+                  s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
+                  break;
+          }
+      }
 
       new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
   }
@@ -1662,16 +1677,19 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
   # pragma unroll N_PER_TH
   for(unsigned int j = 0; j < N_PER_TH; j++)
   {
-      switch(OPTIMIZER)
-      {
-          case MOMENTUM:
-              p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
-              break;
-          case RMSPROP:
-              g_val = g_vals[j];
-              p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
-              break;
-      }
+      if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+      {
+          switch(OPTIMIZER)
+          {
+              case MOMENTUM:
+                  p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
+                  break;
+              case RMSPROP:
+                  g_val = g_vals[j];
+                  p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+                  break;
+          }
+      }
   }
 
   // store: 0.85/1.44 -> 2.48/1.57

@@ -110,7 +110,7 @@ extern "C"
       float eps, int step, float lr, \
       float* quantiles1, float* quantiles2, \
       float* max1, float* max2, float* new_max1, float* new_max2, \
-      float weight_decay, float gnorm_scale, bool skip_zeros, int n) \
+      float weight_decay, float gnorm_scale, int n) \
   { \
       name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
           quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \