diff --git a/Makefile b/Makefile index e114160..a377f65 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,7 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include +INCLUDE_cutlass := -I $(ROOT_DIR)/dependencies/cutlass/include LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags @@ -61,7 +62,7 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 all: $(BUILD_DIR) env - $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) @@ -100,6 +101,11 @@ cuda11x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) +cuda11x_cutlass: $(BUILD_DIR) env cutlass + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++20 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) + cuda12x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o @@ -121,6 +127,11 @@ env: @echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)" @echo "============================" +cutlass: + if [ ! -d "$(ROOT_DIR)/dependencies/cutlass" ]; then \ + git clone https://github.com/NVIDIA/cutlass.git $(ROOT_DIR)/dependencies/cutlass; \ + fi \ + $(BUILD_DIR): mkdir -p build mkdir -p dependencies diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 2d940be..5d2a58e 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2919,10 +2919,35 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * } } + +template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +{ +// element-wise kernel +// 1. Load batch x k into registers +// 2. Load k x k into registers +// 3. dequantize and store in second pair of k x k +// 4. matmul +// 5. sum with cub +// 6. store outputs +// TC kernel +// use k warps per thread block +// 1. threadblock use read-only cache to read in register tile for A into shared memory +// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +// 3. each warp reads a segment of values 16x32 from B +// 4. do dequantization from register of B into second pair of registers +// 5. store (4) into fragment +// 6. matmul aggregate into fragment C +// 7. aggreecate files of C into shared memroy block C +// 8. sum (7) +// 9. write outputs to matmul output matrix +} + + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index ed549cb..ecf3a09 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -9,6 +9,8 @@ #ifndef kernels #define kernels +template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 76777ae..022f397 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -90,6 +90,17 @@ template void dequantizeBlockwise(float *code, unsign CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +{ + int num_blocks = (colsB+32-1)/32; + kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *C, int lda, int ldb, int rowsA, int colsA, int colsB); + + template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, @@ -653,6 +664,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + //============================================================== // TEMPLATE DEFINITIONS //============================================================== diff --git a/csrc/ops.cuh b/csrc/ops.cuh index f73d4e0..137320b 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -183,4 +183,6 @@ template void spmm_coo_very_sparse_naive(int *max_count, template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + #endif