forked from mrq/bitsandbytes-rocm
Added skip_zeros; tests are passing.
parent bb34fd50a1
commit a6eae2e7f2
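skip_zeros tells the optimizer to leave any element whose gradient is exactly zero untouched, so sparse updates do not decay or perturb unused parameters. A minimal usage sketch, assuming the optimizer constructors expose a skip_zeros keyword that ends up in config['skip_zeros'] (that constructor change is not visible in the hunks below):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(64, 64).cuda()
# skip_zeros=True is the assumed keyword; on step(), elements with g == 0
# keep their optimizer state and parameter value unchanged.
opt = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8, skip_zeros=True)

loss = model(torch.randn(4, 64, device="cuda")).sum()
loss.backward()
opt.step()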
Makefile | 46
@@ -15,29 +15,31 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
 LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib

 # NVIDIA NVCC compilation flags
-COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#

-## CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
-#CC_CUDA92 := -gencode arch=compute_30,code=sm_30
-#
-## Later versions of CUDA support the new architectures
-#CC_CUDA10x := -gencode arch=compute_30,code=sm_30
-#CC_CUDA10x += -gencode arch=compute_75,code=sm_75
-#
-#CC_CUDA110 := -gencode arch=compute_75,code=sm_75
-#CC_CUDA110 += -gencode arch=compute_80,code=sm_80
-#
-#CC_CUDA11x := -gencode arch=compute_75,code=sm_75
-#CC_CUDA11x += -gencode arch=compute_80,code=sm_80
-#CC_CUDA11x += -gencode arch=compute_86,code=sm_86
+# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
+CC_CUDA92 := -gencode arch=compute_30,code=sm_30
+
+# Later versions of CUDA support the new architectures
+CC_CUDA10x := -gencode arch=compute_30,code=sm_30
+CC_CUDA10x += -gencode arch=compute_75,code=sm_75
+
+CC_CUDA110 := -gencode arch=compute_75,code=sm_75
+CC_CUDA110 += -gencode arch=compute_80,code=sm_80
+
+CC_CUDA11x := -gencode arch=compute_75,code=sm_75
+CC_CUDA11x += -gencode arch=compute_80,code=sm_80
+CC_CUDA11x += -gencode arch=compute_86,code=sm_86
+COMPUTE_CAPABILITY := -gencode arch=compute_70,code=sm_70 # Volta


 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
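The Makefile hunk above comments out the multi-architecture -gencode list and builds for a single Volta target (sm_70), presumably to keep development builds fast. A small helper sketch (not part of the repo) for printing the -gencode flag that matches the local GPU, e.g. to check whether an sm_70-only build covers it:

import torch

# Query the compute capability of GPU 0 and print the matching NVCC flag.
major, minor = torch.cuda.get_device_capability(0)
cc = f"{major}{minor}"
print(f"-gencode arch=compute_{cc},code=sm_{cc}")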
@@ -336,7 +336,7 @@ class Optimizer2State(Optimizer8bit):
         if state['state1'].dtype == torch.float:
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
-                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])

         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],

@@ -352,7 +352,7 @@ class Optimizer2State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'],
                     state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
-                    config['weight_decay'], gnorm_scale=gnorm_scale)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])


 class Optimizer1State(Optimizer8bit):

@@ -450,7 +450,7 @@ class Optimizer1State(Optimizer8bit):
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     None, 0.0, config['weight_decay'], gnorm_scale,
                     state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
-                    skip_zeros=False)
+                    skip_zeros=config['skip_zeros'])

         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],

@@ -463,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'],
                     state['qmap1'], None, state['absmax1'], None,
-                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
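The optimizer hunks above (the Optimizer2State/Optimizer1State classes, presumably bitsandbytes/optim/optimizer.py) thread config['skip_zeros'] through to the 32-bit and 8-bit blockwise update calls instead of hard-coding skip_zeros=False. A rough reference sketch of the intended 32-bit semantics, with illustrative names and without bias correction or weight decay:

import torch

def adam_update_skip_zeros(p, g, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, skip_zeros):
    # With skip_zeros set, elements whose gradient is exactly zero keep their
    # optimizer state and parameter value; the rest follow the Adam recurrence.
    upd = (g != 0) if skip_zeros else torch.ones_like(g, dtype=torch.bool)
    exp_avg[upd] = exp_avg[upd] * beta1 + (1 - beta1) * g[upd]
    exp_avg_sq[upd] = exp_avg_sq[upd] * beta2 + (1 - beta2) * g[upd] * g[upd]
    p[upd] = p[upd] - lr * exp_avg[upd] / (exp_avg_sq[upd].sqrt() + eps)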
@@ -715,9 +715,12 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
         switch(OPTIMIZER)
         {
            case ADAM:
+              if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+              {
                 s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
                 s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
                 p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+              }
               break;
         }
       }

@@ -864,6 +867,8 @@ __global__ void kOptimizer32bit1State(T *g, T *p,

       # pragma unroll 4
       for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
       {
+        if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+        {
           switch(OPTIMIZER)
           {

@@ -881,6 +886,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
                 break;
           }
+        }
       }

       __syncthreads();
       Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);

@@ -1469,11 +1475,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
       {
          g_val = float(g_vals[j]);
          g_val *= gnorm_scale;
+         if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+         {
            s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
            s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));

            s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
            s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+         }

         new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
         new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));

@@ -1508,11 +1517,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
     // reduce: 2.67/1.69 -> 2.67/1.70
     # pragma unroll N_PER_TH
     for(unsigned int j = 0; j < N_PER_TH; j++)
     {
+       if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+       {
          g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
          if(weight_decay > 0.0f)
              g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
+       }
     }

     // store: 0.85/1.44 -> 2.48/1.57
     __syncthreads();

@@ -1623,6 +1635,8 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
       {
          g_val = float(g_vals[j]);
          g_val *= gnorm_scale;
+         if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+         {
            if(weight_decay > 0.0f)
              g_val += ((float)p_vals[j])*weight_decay;

@@ -1640,6 +1654,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                 s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                 break;
            }
+         }

         new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
       }

@@ -1661,6 +1676,8 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
     // reduce: 2.67/1.69 -> 2.67/1.70
     # pragma unroll N_PER_TH
     for(unsigned int j = 0; j < N_PER_TH; j++)
     {
+       if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+       {
          switch(OPTIMIZER)
          {

@@ -1673,6 +1690,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
              break;
          }
+       }
     }

     // store: 0.85/1.44 -> 2.48/1.57
     __syncthreads();
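In the kernels above, the added guard if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)) (logically the same as !skip_zeros || g_vals[j] != 0) wraps the state update and the parameter or gradient write. For the 8-bit blockwise kernels, the per-block flow around that guard is: dequantize the packed state, update where allowed, recompute the block absmax, and requantize. A simplified sketch with illustrative names, not the library API:

import torch

def blockwise_state_update_sketch(c1, qmap1, absmax1, g, beta1, skip_zeros):
    # c1: uint8 codes for one block, qmap1: 256 quantile values, absmax1: block scale.
    s1 = qmap1[c1.long()] * absmax1                      # dequantize first optimizer state
    upd = (g != 0) if skip_zeros else torch.ones_like(g, dtype=torch.bool)
    s1 = torch.where(upd, s1 * beta1 + (1 - beta1) * g, s1)
    new_absmax1 = s1.abs().max().clamp_min(1e-8)         # new dynamic range for requantization
    dist = (qmap1.view(1, -1) - (s1 / new_absmax1).view(-1, 1)).abs()
    c1_new = dist.argmin(dim=1).to(torch.uint8)          # nearest-quantile requantization
    return c1_new, new_absmax1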
@@ -110,7 +110,7 @@ extern "C"
                float eps, int step, float lr, \
                float* quantiles1, float* quantiles2, \
                float* max1, float* max2, float* new_max1, float* new_max2, \
-               float weight_decay, float gnorm_scale, bool skip_zeros, int n) \
+               float weight_decay, float gnorm_scale, int n) \
    { \
        name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
                            quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \