From 21723f796a3951e56b77460e7d572c76619b773f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 29 Apr 2023 21:52:47 -0700 Subject: [PATCH] 4-bit draft. --- bitsandbytes/functional.py | 22 +++- csrc/kernels.cu | 222 +++++++++++++++++++++++++++++++++---- csrc/kernels.cuh | 1 + csrc/ops.cu | 18 +++ csrc/ops.cuh | 1 + csrc/pythonInterface.c | 6 + tests/test_functional.py | 30 ++++- 7 files changed, 273 insertions(+), 27 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b5c622b..f725c1c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1380,10 +1380,15 @@ def cutlass3_gemm( out: Tensor = None, transposed_A=False, transposed_B=False, + state=None ): - sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + if state is None: + Bshape = B.shape + else: + Bshape = state[1] if out is None: - out = torch.zeros(size=sout, dtype=A.dtype, device=A.device) + out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device) sA = A.shape sB = B.shape @@ -1456,7 +1461,13 @@ def cutlass3_gemm( # [km, nk -> mn] #lda = ldb = ldc = 1 #lda = 1 - #print(m, n, k, lda, ldb, ldc) + if state is not None: + m = Bshape[0] + k = Bshape[1] + lda = Bshape[1] + ldc = Bshape[0] + ldb = (ldb+1)//2 + print(m, n, k, lda, ldb, ldc) is_on_gpu([B, A, out]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -1464,7 +1475,10 @@ def cutlass3_gemm( lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - if A.dtype == torch.float32: + + if B.dtype == torch.uint8: + lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + elif A.dtype == torch.float32: lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) elif A.dtype == torch.float16: lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a5697ee..53a183d 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -69,6 +69,27 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax) } } +__device__ float d2DequantizeFP4(unsigned char val) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction; + } +} + __device__ float dDequantizeFP4Tree(unsigned char val, float absmax) { float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; @@ -145,7 +166,7 @@ __device__ unsigned char dQuantizeFP4(float x) return 0b0000+sign; } -__device__ float dDequantizeNF4(unsigned char val, float absmax) +__device__ half dhDequantizeNF4(unsigned char val) { // the values for this tree was generated by test_normal_map_tree // in the file tests/test_functional.py @@ -153,49 +174,103 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f*absmax; + return 1.0f; else - return 0.7229568362236023f*absmax; + return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f*absmax; + return 0.5626170039176941f; else - return 0.44070982933044434f*absmax; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f*absmax; + return 0.33791524171829224f; else - return 0.24611230194568634f*absmax; + return 0.24611230194568634f; else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f*absmax; + return 0.16093020141124725f; else - return 0.07958029955625534f*absmax; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f*absmax; + return 0.0f; else - return -0.09105003625154495f*absmax; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f*absmax; + return -0.18477343022823334f; else - return -0.28444138169288635f*absmax; + return -0.28444138169288635f; else if((val & 0b0010) == 2) //00 if((val & 0b0001) == 1) // 001 - return -0.39491748809814453f*absmax; + return -0.39491748809814453f; else - return -0.5250730514526367f*absmax; + return -0.5250730514526367f; else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f*absmax; + return -0.6961928009986877f; else - return -1.0f*absmax; + return -1.0f; + +} + +__device__ float dDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; } @@ -800,8 +875,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { - vals[j*2] = dDequantizeNF4(qvals[j] >> 4, local_abs_max); - vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F, local_abs_max); + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; 
} break; } @@ -2947,7 +3022,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// 9. write outputs to matmul output matrix //} -template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit) +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) { if(limit_base + ITEMS <= limit) reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; @@ -2958,7 +3033,7 @@ template __device__ inline void vector_l if(limit_base + k < limit) local[k] = buffer[idx+k]; else - local[k] = 0.0f; + local[k] = (T)zero_value; } } } @@ -3024,6 +3099,109 @@ template __global__ void gemm_device(int M, out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; } +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage reduce; + int col_offset = blockIdx.x *8; + + T local_A[32]; + unsigned char local_B_4bit[16]; + T local_B[32]; + T local_C[8]; + + __shared__ T smem_C[8]; + + if(threadIdx.x < 8) + smem_C[threadIdx.x] = T(0); + __syncthreads(); + + #pragma unroll 8 + for(int k = 0; k < 8; k++) + local_C[k] = T(0); + + + for(int idx = threadIdx.x*32; idx < K; idx+=blockDim.x*32) + { + + // we load only 8 values per iteration from A, so we + // need to do 4 loads for every single load from B + // for B, we have packed values, so the 16 8-bit values + // turn into 32 4-bit values to 4x 4 loads turns into 4x 8 loads + vector_load(local_A, A, idx, idx, K); + vector_load(&(local_A[8]), A, idx+8, idx+8, K); + vector_load(&(local_A[16]), A, idx+16, idx+16, K); + vector_load(&(local_A[24]), A, idx+24, idx+24, K); + + for(int col = 0; col < 8; col++) + { + if((col + col_offset) >= M){ break; } + + int offset_B = (col_offset+col)*ldb; + // 0111 -> 0.0f in NF4 + // since we have packed 8-bits, we need cat(0b0111, 0b0111) = 0b01110111 + vector_load(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111); + + int absidx = (idx + offset_B)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + //for(int k = 0; k < 16; k++) + //printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F); + //printf("\n"); + + //vector_load(local_A, A, idx, idx, K); + + #pragma unroll 16 + for(int k = 0; k < 16; k++) + { + + //if(local_B_4bit[k ] != 0b01110111) + //printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax), + //local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax)); + //local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax; + //local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax; + local_B[k*2] = (half)(local_B_4bit[k] >> 4)*local_absmax; + local_B[k*2 + 1] = (half)(local_B_4bit[k] & 0x0F)*local_absmax; + //local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax; + //local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax; + } + + #pragma unroll 32 + //for(int k = 0; k < 8; k++) + for(int k = 0; k < 32; k++) + { + local_C[col] += local_A[k]*local_B[k]; + //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0) + //if((float)local_B[k] != 0.0) + //printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]); + } + } + } + + 
#pragma unroll 8 + for(int k = 0; k < 8; k++) + { + local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum()); + __syncthreads(); + } + + if(threadIdx.x == 0) + { + #pragma unroll 8 + for(int k = 0; k < 8; k++) + smem_C[k] = local_C[k]; + } + else if(threadIdx.x >= 32) + // early return for unused warps + return; + + __syncwarp(); + + + if(threadIdx.x < 8 && col_offset + threadIdx.x < M) + out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; +} + //#define ROWS 2 //template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) //{ @@ -3207,6 +3385,8 @@ template __global__ void gemm_device(int M, int N, int K, half * template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index aab7b95..4951031 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -139,5 +139,6 @@ template __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz); template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 2219690..07e7107 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -695,10 +695,28 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out gemm_device<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + dim3 dimBlock(128); + int num_blocks = (m+7)/8; + + cout << num_blocks << endl; + cout << lda << endl; + cout << ldb << endl; + cout << ldc << endl; + + cout << m << endl; + cout << n << endl; + cout << k << endl; + kgemm_4bit_inference<<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index ffc9e87..8919c60 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -191,6 +191,7 @@ template void extractOutliers(char * A, int 
*idx, char *out, int id void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); void pipeline_test(float *A, float *B, size_t n, size_t batch_size); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 1ece3e6..bdf821c 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -25,6 +25,9 @@ void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, in void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) { gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } +void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + #define MAKE_FUNC32(fname, oname, gtype, gbits) \ void fname##32bit_g##gbits(gtype *g, gtype *p, \ @@ -319,6 +322,9 @@ extern "C" void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } + void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index b256af9..f58cd43 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2352,8 +2352,8 @@ def test_normal_map_tree(): print(pivots) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) -#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): for i in range(1): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') @@ -2373,6 +2373,32 @@ def test_cutlass3_gemm(dtype): torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005) +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_gemm_4bit(dtype): + for i in range(1): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #torch.random.manual_seed(17) + A = torch.rand(1, 4096, dtype=dtype, device='cuda') + B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + + #print('') + #print(A) + #print(B) + + qB, state = F.quantize_nf4(B) + F.dequantize_nf4(qB, state) + + + C1 = torch.matmul(A, B.t()) + #C1 = bnb.matmul_4bit(A, qB.t(), state) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + #print(C1) + #print(C2) + + #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005) + def test_pipeline_func(): a = torch.rand(2, 4).cuda()
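
Note (not part of the patch): a minimal usage sketch of the new 4-bit inference path, following test_gemm_4bit above. F.quantize_nf4 / F.dequantize_nf4 and the state argument of F.cutlass3_gemm are taken from the hunks in this patch; everything else is illustrative. Because the draft kernel currently dequantizes B with a plain (half)(nibble) * absmax cast (the dDequantizeNF4 calls are commented out), the kernel output is not yet expected to match the reference, which is also why the tolerance check in test_gemm_4bit is still commented out.

    import torch
    import bitsandbytes.functional as F

    A = torch.rand(1, 4096, dtype=torch.float16, device='cuda')
    B = torch.rand(4 * 4096, 4096, dtype=torch.float16, device='cuda')

    # pack B into 4-bit NF4 values; state carries (absmax, shape, ..., blocksize)
    qB, state = F.quantize_nf4(B)

    # new path added in this patch: B stays packed, the kernel dequantizes on the fly
    C_kernel = F.cutlass3_gemm(A, qB.t(), state=state)

    # reference: dequantize B and run an ordinary fp16 matmul
    C_ref = torch.matmul(A, F.dequantize_nf4(qB, state).t())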
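
For reference, this is the NF4 codebook that dDequantizeNF4 / dhDequantizeNF4 encode as a bit tree, together with the nibble unpacking used by kDequantizeBlockwise (high nibble first, per-block absmax scaling). A PyTorch sketch, assuming a flat row-major block layout and an element count divisible by the blocksize; the real F.dequantize_nf4 handles the shape and state bookkeeping.

    import torch

    # NF4 code values for nibbles 0..15, copied from dDequantizeNF4 in csrc/kernels.cu
    NF4_CODE = torch.tensor([
        -1.0,                 -0.6961928009986877,  -0.5250730514526367,  -0.39491748809814453,
        -0.28444138169288635, -0.18477343022823334, -0.09105003625154495,  0.0,
         0.07958029955625534,  0.16093020141124725,  0.24611230194568634,  0.33791524171829224,
         0.44070982933044434,  0.5626170039176941,   0.7229568362236023,   1.0])

    def dequantize_nf4_reference(packed, absmax, blocksize):
        # unpack two 4-bit values per byte; kDequantizeBlockwise emits the high nibble first
        packed = packed.flatten()
        hi = (packed >> 4).long()
        lo = (packed & 0x0F).long()
        nibbles = torch.stack((hi, lo), dim=-1).flatten()
        vals = NF4_CODE[nibbles]
        # each block of `blocksize` consecutive values shares one absmax scale
        return (vals.view(-1, blocksize) * absmax.view(-1, 1).float()).flatten()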
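
The host-side launcher added in ops.cu maps one 128-thread block to 8 output rows, and the Python wrapper halves the B leading dimension (ldb = (ldb + 1) // 2) because two 4-bit values share a byte. The grid arithmetic, written out as a small sketch for clarity:

    def grid_for_gemm_4bit(m, rows_per_block=8, threads_per_block=128):
        # same ceiling division as `int num_blocks = (m+7)/8;` with `dim3 dimBlock(128);`
        # in gemm_4bit_inference (csrc/ops.cu)
        num_blocks = (m + rows_per_block - 1) // rows_per_block
        return num_blocks, threads_per_block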