diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f725c1c..b4cbd28 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1467,7 +1467,7 @@ def cutlass3_gemm( lda = Bshape[1] ldc = Bshape[0] ldb = (ldb+1)//2 - print(m, n, k, lda, ldb, ldc) + #print(m, n, k, lda, ldb, ldc) is_on_gpu([B, A, out]) m = ct.c_int32(m) n = ct.c_int32(n) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index b03c6ca..477904c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3061,9 +3061,8 @@ template __global__ void gemm_device(int M, T local_A[1]; T local_B[32]; - const int a_tile_offset = (8*16 + 16); - const int b_tile_offset = (16*32 + 16); - const int c_tile_offset = 8*32 + 24; + const int a_tile_offset = (8*16); + const int b_tile_offset = (16*32); __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; @@ -3109,6 +3108,19 @@ template __global__ void gemm_device(int M, for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col]; } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = T(0.0); + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = T(0.0f); + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = T(0.0f); + } ticktock = ticktock == 0 ? 1 : 0; for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32) @@ -3130,6 +3142,19 @@ template __global__ void gemm_device(int M, for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } ticktock = ticktock == 0 ? 1 : 0; if(warp_id == (WARPS-1)) diff --git a/csrc/ops.cu b/csrc/ops.cu index 2ccb418..6bf1e89 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -680,14 +680,14 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out int num_blocks = (m+31)/32; - cout << num_blocks << endl; - cout << lda << endl; - cout << ldb << endl; - cout << ldc << endl; + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; - cout << m << endl; - cout << n << endl; - cout << k << endl; + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; //if(bits == 32) //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); diff --git a/tests/test_functional.py b/tests/test_functional.py index 5f90f69..25fbb5b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2355,25 +2355,47 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - for i in range(1): + for i in range(100): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') #A = torch.rand(1, 4096, dtype=dtype, device='cuda') #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') - A = torch.rand(1, 4096, dtype=dtype, device='cuda') - B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + A = torch.randn(1, 128+32, dtype=dtype, device='cuda') + B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128) #print('') #print(A) #print(B.t()) + #A[:, :-3] = 0 + #B[:, :-3] = 0 C1 = torch.matmul(A, B.t()) C2 = F.cutlass3_gemm(A, B.t()) - print(C1) - print(C2) + err = C1-C2 - torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.06) + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(err.mean()).item() + mag = torch.abs(C1).mean() + relerr = err/mag + + if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + print('') + print(i, err, mag.item(), relerr.item()) + print(A.flatten()[-6:]) + print(B.flatten()[-6:]) + out = A.flatten()[-6:]*B.flatten()[-6:] + print(out) + print(out[:-1].sum()) + print('='*80) + print(C1.flatten()[-6:]) + print(C2.flatten()[-6:]) + #assert False, 'ERROR' + + c = int(C1.numel()*0.001) + assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])