From 3aef78342aec4fff1922c0c2cdd83bdda928b536 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 28 Apr 2023 17:34:08 -0700 Subject: [PATCH] Added template refactor. --- bitsandbytes/functional.py | 4 +--- csrc/kernels.cu | 23 ++++++++++------------- csrc/kernels.cuh | 6 +----- csrc/ops.cu | 11 +++-------- csrc/ops.cuh | 7 +------ csrc/pythonInterface.c | 19 ++++--------------- 6 files changed, 20 insertions(+), 50 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 774e954..da4e66c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1464,9 +1464,7 @@ def cutlass3_gemm( lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - alpha = ct.c_float(1.0) - beta = ct.c_float(0.0) - lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), lda, get_ptr(B), ldb, beta, get_ptr(out), ldc) + lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 91169dd..45db448 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2949,22 +2949,18 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * #define ROWS 2 -__global__ void gemm_device(int M, int N, int K, - float const* A, - float* B, - float * out, int lda, int ldb, int ldc, - float alpha, float beta) +template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) { // 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp // 1. Load dataB into register // 2. Dequantize B // 3. Fetch data from A and multiply - typedef cub::BlockLoad LoadA; + typedef cub::BlockLoad LoadA; //__shared__ typename LoadA::TempStorage loada; - typedef cub::BlockLoad LoadB; + typedef cub::BlockLoad LoadB; //__shared__ typename LoadB::TempStorage loadb; - typedef cub::BlockReduce BlockReduce; + typedef cub::BlockReduce BlockReduce; // Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; @@ -2975,16 +2971,16 @@ __global__ void gemm_device(int M, int N, int K, } temp_storage; - float dataA[4]; - float local_B[4]; - float local_accC[ROWS]; + T dataA[4]; + T local_B[4]; + T local_accC[ROWS]; int valid_items = 0; const int warp_id = threadIdx.x/32; const int warp_lane = threadIdx.x % 32; const int col_offset = blockIdx.x * 8; - __shared__ float tileA[ROWS*1024]; - __shared__ float accumulatorC[ROWS*8]; + __shared__ T tileA[ROWS*1024]; + __shared__ T accumulatorC[ROWS*8]; //#pragma unroll 8 //for(int i = 0; i < 8; i++) @@ -3128,6 +3124,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB, // TC * out, CStride dC, CBlockLayout , CThreadLayout tC, // half alpha, half beta); +template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc); //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 55397e7..900af90 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -138,10 +138,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); -__global__ void gemm_device(int M, int N, int K, - float const* A, - float * B, - float * out, int lda, int ldb, int ldc, - float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc); #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index dd8fade..6aaa241 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -675,12 +675,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) -void gemm_host(int m, int n, int k, - float alpha, - float const* A, int lda, - float * B, int ldb, - float beta, - float * C, int ldc) +template void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc) { dim3 dimBlock(256); @@ -699,14 +694,14 @@ void gemm_host(int m, int n, int k, (m, n, k, A, B, - C, lda, ldb, ldc, - alpha, beta); + out, lda, ldb, ldc); } //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void gemm_host(int m, int n, int k, float const* A, float* B, float * out, int lda, int ldb, int ldc); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 2f71966..b7ef9a3 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -190,12 +190,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); -void gemm_host(int m, int n, int k, - float alpha, - float const* A, int ldA, - float * B, int ldB, - float beta, - float * C, int ldC); +template void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc); void pipeline_test(float *A, float *B, size_t n, size_t batch_size); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 6ec5501..a7c4787 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -20,14 +20,8 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } -void -cppgemm(int m, int n, int k, - float alpha, - float const* A, int ldA, - float * B, int ldB, - float beta, - float * C, int ldC) -{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);} +void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc) +{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc); } #define MAKE_FUNC32(fname, oname, gtype, gbits) \ @@ -317,13 +311,8 @@ extern "C" void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); } - void ccutlass_gemm(int m, int n, int k, - float alpha, - float const* A, int ldA, - float * B, int ldB, - float beta, - float * C, int ldC) - { cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);} + void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc) + { gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }