From 5fab6734424a78a2a4594525386cd84feb67fb50 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sun, 9 Jul 2023 21:06:01 -0700
Subject: [PATCH] Added fp32 compute type for gemv_4bit.

---
 bitsandbytes/functional.py | 96 ++++++----------------------------
 csrc/kernels.cu            | 31 ++++++++----
 csrc/kernels.cuh           |  2 +-
 csrc/ops.cu                | 10 ++--
 csrc/ops.cuh               |  2 +-
 csrc/pythonInterface.c     | 10 +++-
 tests/test_functional.py   |  8 ++--
 7 files changed, 55 insertions(+), 104 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index c5514ed..1972462 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1464,6 +1464,9 @@ def gemv_4bit(
     if state is None:
         raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')
+    if A.numel() != A.shape[-1]:
+        raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
+
     Bshape = state[1]
     bout = Bshape[0]
     absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
@@ -1474,90 +1477,17 @@ def gemv_4bit(
     if out is None:
         if len(A.shape) == 3:
-            out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
+            out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
         else:
-            out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
+            out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
-
-
-
-
-    sA = A.shape
-    sB = B.shape
-    if transposed_A and len(sA) == 2:
-        sA = (sA[1], sA[0])
-    elif transposed_A and len(sA) == 3:
-        sA = (sA[0], sA[2], sA[0])
-    if transposed_B and len(sB) == 2:
-        sB = (sB[1], sB[0])
-    elif transposed_B and len(sB) == 3:
-        sB = (sB[0], sB[2], sB[0])
-    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
-    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
-    # (transpose of row major is column major)
-    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
-
-    # matrices in the input arguments for cuBLAS
-    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
-    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
-    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
-    if len(sB) == 2:
-        if B.stride()[0] == B.shape[1]:
-            transposed_B = False
-        elif B.stride()[1] == B.shape[0]:
-            transposed_B = True
-        if len(A.shape) == 2:
-            if A.stride()[0] == A.shape[1]:
-                transposed_A = False
-            elif A.stride()[1] == A.shape[0]:
-                transposed_A = True
-        else:
-            if A.stride()[1] == A.shape[2]:
-                transposed_A = False
-            elif A.stride()[2] == A.shape[1]:
-                transposed_A = True
-
-        if len(sA) == 2:
-            n = sA[0]
-            ldb = A.stride()[1 if transposed_A else 0]
-        elif len(sA) == 3 and len(sB) == 2:
-            n = sA[0] * sA[1]
-            ldb = sA[2]
-
-        m = sB[1]
-        k = sB[0]
-        lda = B.stride()[0]
-        ldc = sB[1]
-    elif len(sB) == 3:
-        # special case
-        assert len(sA) == 3
-        if not (sA[0] == sB[0] and sA[1] == sB[1]):
-            raise ValueError(
-                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
-            )
-
-        transposed_A = True
-        transposed_B = False
-
-        m = sB[2]
-        n = sA[2]
-        k = sB[0] * sB[1]
-
-        lda = n
-        ldb = sA[2]
-        ldc = m
-
-    # B^T @ A^T = C^T
-    # [km, nk -> mn]
-    #lda = ldb = ldc = 1
-    #lda = 1
-    if state is not None:
-        m = Bshape[0]
-        k = Bshape[1]
-        lda = Bshape[0]
-        ldc = Bshape[0]
-        ldb = (ldb+1)//2
-    #print(m, n, k, lda, ldb, ldc)
-    is_on_gpu([B, A, out])
+    n = 1
+    m = Bshape[0]
+    k = Bshape[1]
+    lda = Bshape[0]
+    ldc = Bshape[0]
+    ldb = (A.shape[-1]+1)//2
+    is_on_gpu([B, A, out, absmax, state[-1]])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
@@ -1570,6 +1500,8 @@ def gemv_4bit(
             lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
             lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+        elif A.dtype == torch.float32:
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     else:
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 1aaeb22..4b05672 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3520,7 +3520,7 @@ template __global__ void kgemm_4bit_inference(int M, i
 }
 #define num_values_4bit 32
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {
   // per threadblock:
@@ -3528,7 +3528,6 @@ template __global__ void kgemm_4bit_inference_naive(in
   // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
   // 4 warps -> 4 loads per iter
   // 1x128 * 128x4 -> 1x4 outputs
-  //typedef cub::WarpReduce<T> WarpReduce;
   typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
@@ -3536,7 +3535,6 @@ template __global__ void kgemm_4bit_inference_naive(in
   const int warp_lane = threadIdx.x % 32;
   const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
   const int num_values_8bit = num_values_4bit/2;
-  //T local_C = T(0.0f);
   float local_C = 0.0f;
   unsigned char local_B_4bit[num_values_8bit];
@@ -3585,10 +3583,24 @@ template __global__ void kgemm_4bit_inference_naive(in
     if(inner_idx+num_values_4bit < K)
     {
-      reinterpret_cast<float4*>(local_A)[0] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_8bit/2) + 0];
-      reinterpret_cast<float4*>(local_A)[1] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_8bit/2) + 1];
-      reinterpret_cast<float4*>(local_A)[2] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_8bit/2) + 2];
-      reinterpret_cast<float4*>(local_A)[3] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_8bit/2) + 3];
+      if(BITS==16)
+      {
+        reinterpret_cast<float4*>(local_A)[0] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/4) + 0];
+        reinterpret_cast<float4*>(local_A)[1] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/4) + 1];
+        reinterpret_cast<float4*>(local_A)[2] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/4) + 2];
+        reinterpret_cast<float4*>(local_A)[3] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/4) + 3];
+      }
+      else
+      {
+        reinterpret_cast<float4*>(local_A)[0] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/8) + 0];
+        reinterpret_cast<float4*>(local_A)[1] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/8) + 1];
+        reinterpret_cast<float4*>(local_A)[2] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/8) + 2];
+        reinterpret_cast<float4*>(local_A)[3] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/8) + 3];
+        reinterpret_cast<float4*>(local_A)[4] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/8) + 4];
+        reinterpret_cast<float4*>(local_A)[5] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/8) + 5];
+        reinterpret_cast<float4*>(local_A)[6] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/8) + 6];
+        reinterpret_cast<float4*>(local_A)[7] = reinterpret_cast<float4*>(A)[inner_idx/(num_values_4bit/8) + 7];
+      }
     }
     else
@@ -3776,8 +3788,9 @@ template __global__ void kgemm_4bit_inference(int M, int N, int K, ha
 template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index d5349d6..a7fe3d7 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -125,7 +125,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kfunc(T *A, T *B, T value, long n);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 8bcee2c..b524e0e 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -729,12 +729,12 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {
   int num_blocks = (m+3)/4;
-  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
 }
 template void func(T *A, T *B, T value, long n)
@@ -757,8 +757,10 @@ template void func(float *A, float *B, float value, long n);
 template void func(float *A, float *B, float value, long n);
 template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
+
 //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 699ff20..f37b3b3 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
 template void func(T *A, T *B, T value, long n);
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index b1a079f..0aa82fe 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -29,10 +29,13 @@ void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, floa
 { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
 void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+{ gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
 void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+
+void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
 #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
 void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \
@@ -400,6 +403,9 @@ extern "C"
     void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
     { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+    void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
+    { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+
 #endif
     void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
     void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 6dff784..34552cb 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)
-batch_size = 5
+batch_size = 1
 seqdim = 1
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))
@@ -1793,7 +1793,7 @@ values.append((batch_size, seqdim, 6656, 4*6656))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
-    iters = 80
+    iters = 1000
     formatB = F.get_special_format_str()
     A = torch.randn(batch, seq, model, device="cuda").half()
@@ -2361,9 +2361,7 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
-#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
 def test_gemv_4bit(dtype, storage_type, double_quant):
     print('')
    for dim in [128, 256, 512, 1024, 2048, 4096]:
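
For reference, below is a minimal sketch of how the new fp32 compute path can be exercised from Python, loosely mirroring tests/test_functional.py::test_gemv_4bit. It is not part of the patch; the exact quantize_4bit()/gemv_4bit() call signatures and the qB.t() packing convention are assumed from this patch and the surrounding library, and a CUDA GPU is required.

# Minimal sketch (assumed API): quantize_4bit() is assumed to return
# (packed_weights, quant_state), and gemv_4bit() to accept a [1, 1, k]
# activation plus that state, as required by the new shape check above.
import math
import torch
import bitsandbytes.functional as F

dim = 4096
A = torch.randn(1, 1, dim, dtype=torch.float32, device='cuda')                      # fp32 activations (newly supported)
B = torch.randn(4 * dim, dim, dtype=torch.float32, device='cuda') / math.sqrt(dim)  # dense weight to be quantized

qB, state = F.quantize_4bit(B, quant_type='nf4')   # 4-bit NF4 weights + quantization state
C = F.gemv_4bit(A, qB.t(), state=state)            # A.dtype == torch.float32 now dispatches to cgemm_4bit_inference_naive_fp32
print(C.shape, C.dtype)                            # expected: torch.Size([1, 1, 16384]) torch.float32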