From 264a948539d219e6b9a8fc8b9d92120d76b8878b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 16:15:38 -0700 Subject: [PATCH] 4-bit draft; 128 vector load 240. --- bitsandbytes/functional.py | 6 +- csrc/kernels.cu | 307 ++++++++++++++++++++++++------------- csrc/ops.cu | 18 +-- tests/test_functional.py | 95 ++++++++---- 4 files changed, 284 insertions(+), 142 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b4cbd28..e5b1bf7 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1385,10 +1385,12 @@ def cutlass3_gemm( #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: Bshape = B.shape + bout = Bshape[1] else: Bshape = state[1] + bout = Bshape[0] if out is None: - out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device) + out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) sA = A.shape sB = B.shape @@ -1464,7 +1466,7 @@ def cutlass3_gemm( if state is not None: m = Bshape[0] k = Bshape[1] - lda = Bshape[1] + lda = Bshape[0] ldc = Bshape[0] ldb = (ldb+1)//2 #print(m, n, k, lda, ldb, ldc) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 65ed19e..2373b91 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3044,22 +3044,15 @@ template __device__ inline void vector_l #define WARPS 5 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { - - typedef cub::WarpReduce WarpReduce; - // Allocate WarpReduce shared memory for one warp - //__shared__ typename WarpReduce::TempStorage temp_storage; - - //typedef cub::BlockReduce BlockReduce; - //// Allocate shared memory for BlockReduce - //__shared__ typename BlockReduce::TempStorage reduce; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; - T local_A[2]; - T local_B[64]; + T local_A[4]; + T local_B[128]; const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); @@ -3082,24 +3075,45 @@ template __global__ void gemm_device(int M, if(loaded_values == 0) { local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; #pragma unroll 32 for(int col = 0; col < 32; col++) { local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; } - loaded_values = 1; + loaded_values = 3; } else { - local_A[0] = local_A[1]; - loaded_values--; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+32]; + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; @@ -3139,26 +3153,46 @@ template __global__ void gemm_device(int M, if(loaded_values == 0) { local_A[0] = A[idx]; - local_A[1] = A[idx+blockDim.x-32]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; #pragma unroll 32 for(int col = 0; col < 32; col++) { local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; } - loaded_values = 1; + loaded_values = 3; + } else { - local_A[0] = local_A[1]; + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } loaded_values--; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+32]; - - } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; @@ -3215,104 +3249,166 @@ template __global__ void gemm_device(int M, template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage reduce; - int col_offset = blockIdx.x *8; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; - T local_A[32]; - unsigned char local_B_4bit[16]; - T local_B[32]; - T local_C[8]; + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; - __shared__ T smem_C[8]; + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); - if(threadIdx.x < 8) - smem_C[threadIdx.x] = T(0); - __syncthreads(); + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; - #pragma unroll 8 - for(int k = 0; k < 8; k++) - local_C[k] = T(0); + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); - - for(int idx = threadIdx.x*32; idx < K; idx+=blockDim.x*32) + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) { - - // we load only 8 values per iteration from A, so we - // need to do 4 loads for every single load from B - // for B, we have packed values, so the 16 8-bit values - // turn into 32 4-bit values to 4x 4 loads turns into 4x 8 loads - vector_load(local_A, A, idx, idx, K); - vector_load(&(local_A[8]), A, idx+8, idx+8, K); - vector_load(&(local_A[16]), A, idx+16, idx+16, K); - vector_load(&(local_A[24]), A, idx+24, idx+24, K); - - for(int col = 0; col < 8; col++) + if(loaded_values == 0) { - if((col + col_offset) >= M){ break; } - - int offset_B = (col_offset+col)*ldb; - // 0111 -> 0.0f in NF4 - // since we have packed 8-bits, we need cat(0b0111, 0b0111) = 0b01110111 - vector_load(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111); - - int absidx = (idx + offset_B)/blocksize; - half local_absmax = __ldg(&(absmax[absidx])); - //for(int k = 0; k < 16; k++) - //printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F); - //printf("\n"); - - //vector_load(local_A, A, idx, idx, K); - - #pragma unroll 16 - for(int k = 0; k < 16; k++) - { - - //if(local_B_4bit[k ] != 0b01110111) - //printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax), - //local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax)); - //local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax; - //local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax; - local_B[k*2] = (half)(local_B_4bit[k] >> 4)*local_absmax; - local_B[k*2 + 1] = (half)(local_B_4bit[k] & 0x0F)*local_absmax; - //local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax; - //local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax; - } + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; #pragma unroll 32 - //for(int k = 0; k < 8; k++) - for(int k = 0; k < 32; k++) + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) { - local_C[col] += local_A[k]*local_B[k]; - //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0) - //if((float)local_B[k] != 0.0) - //printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]); + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); } } - } - #pragma unroll 8 - for(int k = 0; k < 8; k++) + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) { - local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum()); + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } } - if(threadIdx.x == 0) + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) { - #pragma unroll 8 - for(int k = 0; k < 8; k++) - smem_C[k] = local_C[k]; + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } - else if(threadIdx.x >= 32) - // early return for unused warps - return; - __syncwarp(); + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); - - if(threadIdx.x < 8 && col_offset + threadIdx.x < M) - out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; } //#define ROWS 2 @@ -3513,6 +3609,7 @@ template __global__ void gemm_device(int M, int N, int K, half * _ template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB); diff --git a/csrc/ops.cu b/csrc/ops.cu index 16d82f9..4d68436 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -703,17 +703,17 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { - int num_blocks = (m+7)/8; + int num_blocks = (m+31)/32; - cout << num_blocks << endl; - cout << lda << endl; - cout << ldb << endl; - cout << ldc << endl; + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; - cout << m << endl; - cout << n << endl; - cout << k << endl; - kgemm_4bit_inference<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } diff --git a/tests/test_functional.py b/tests/test_functional.py index e9a67f5..dc4e40d 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2358,20 +2358,19 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + debug = True + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: - #for dim in [4096]: + for dim in [4096]: + #for dim in [128+1]: errs = [] relerrs = [] max_err = 0 max_relerr = 0 for i in range(100): - #A = torch.rand(2, 4092, dtype=dtype, device='cuda') - #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') - #A = torch.rand(1, 4096, dtype=dtype, device='cuda') - #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') - A = torch.randn(1, dim+0, dtype=dtype, device='cuda') + A = torch.randn(1, dim, dtype=dtype, device='cuda') B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim) #print('') #print(A) @@ -2397,7 +2396,7 @@ def test_cutlass3_gemm(dtype): errs.append(err) relerrs.append(relerr) - #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: # print('') # print(i, err, relerr) # print(A.flatten()[-6:]) @@ -2412,7 +2411,7 @@ def test_cutlass3_gemm(dtype): c = int(C1.numel()*0.0014*(dim/256))+1 - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True) + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug) #print(c/math.sqrt(dim)) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim)) @@ -2422,29 +2421,73 @@ def test_cutlass3_gemm(dtype): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_gemm_4bit(dtype): - for i in range(1): - #A = torch.rand(2, 4092, dtype=dtype, device='cuda') - #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') - #torch.random.manual_seed(17) - A = torch.rand(1, 4096, dtype=dtype, device='cuda') - B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + #for dim in [32]: + for dim in [4096]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(1): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #A = torch.rand(1, 4096, dtype=dtype, device='cuda') + #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + A = torch.randn(1, dim+0, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) - #print('') - #print(A) - #print(B) + #print('') + #print(A) + #print(B.t()) + #A[:, :-1] = 0 + #B[:, :-1] = 0 - qB, state = F.quantize_nf4(B) - F.dequantize_nf4(qB, state) + qB, state = F.quantize_nf4(B) + F.dequantize_nf4(qB, state) + C3 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + C1 = bnb.matmul_4bit(A, qB.t(), state) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) - C1 = torch.matmul(A, B.t()) - #C1 = bnb.matmul_4bit(A, qB.t(), state) - C2 = F.cutlass3_gemm(A, qB.t(), state=state) - #print(C1) - #print(C2) + print(C1.shape, C2.shape) - #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005) + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + errs.append(err) + relerrs.append(relerr) + + if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + print('') + print(i, err, relerr) + print(A.flatten()[-6:]) + print(B.flatten()[-6:]) + out = A.flatten()[-6:]*B.flatten()[-6:] + print(out) + print(out[:-1].sum()) + print('='*80) + print(C1.flatten()[-6:]) + print(C2.flatten()[-6:]) + #assert False, 'ERROR' + + c = int(C1.numel()*0.0014*(dim/256))+1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + #print(c/math.sqrt(dim)) + print('') + print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) def test_pipeline_func(): a = torch.rand(2, 4).cuda()