From 9192c9de648338dd9281368ed0bff20dc123490b Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Tue, 2 May 2023 07:50:32 -0700
Subject: [PATCH] Tighter and scaled error analysis.

---
 csrc/kernels.cu          | 15 ++++++-
 tests/test_functional.py | 85 +++++++++++++++++++++++-----------------
 2 files changed, 64 insertions(+), 36 deletions(-)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 477904c..2fa288f 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3123,6 +3123,7 @@ template __global__ void gemm_device(int M,
   }
 
   ticktock = ticktock == 0 ? 1 : 0;
+  //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32)
   {
     idx = base_idx + threadIdx.x;
@@ -3155,8 +3156,9 @@ template __global__ void gemm_device(int M,
       for(int col = 0; col < 32; col++)
         smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
     }
-    ticktock = ticktock == 0 ? 1 : 0;
+    //ticktock = ticktock == 0 ? 1 : 0;
 
+    __syncthreads();
     if(warp_id == (WARPS-1))
       for(int k = 0; k < batch_size_warps; k++)
       {
@@ -3166,11 +3168,22 @@ template __global__ void gemm_device(int M,
       }
   }
 
+  //__syncthreads();
+  //if(warp_id == (WARPS-1))
+  //  for(int k = 0; k < batch_size_warps; k++)
+  //  {
+  //    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+  //    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+  //    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+  //  }
+
   __syncthreads();
+  // 129 mu
   if(warp_id == (WARPS-1))
     wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
   __syncthreads();
 
+  //if(threadIdx.x >= 16){ return; }
   //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
 
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 25fbb5b..0500984 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2355,47 +2355,62 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(100):
-        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-        A = torch.randn(1, 128+32, dtype=dtype, device='cuda')
-        B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128)
+    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
+        for i in range(100):
+            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            B = torch.randn(4*496, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
 
-        #print('')
-        #print(A)
-        #print(B.t())
-        #A[:, :-3] = 0
-        #B[:, :-3] = 0
+            #print('')
+            #print(A)
+            #print(B.t())
+            #A[:, :-3] = 0
+            #B[:, :-3] = 0
 
-        C1 = torch.matmul(A, B.t())
-        C2 = F.cutlass3_gemm(A, B.t())
-        err = C1-C2
+            C1 = torch.matmul(A, B.t())
+            C2 = F.cutlass3_gemm(A, B.t())
 
-        # tensor cores are non-deterministic
-        # so we need to analyze errors around the mean
-        # to test our implementation
-        err = torch.abs(err.mean()).item()
-        mag = torch.abs(C1).mean()
-        relerr = err/mag
+            # tensor cores are non-deterministic
+            # so we need to analyze errors around the mean
+            # to test our implementation
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
+            relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
 
-        if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-            print('')
-            print(i, err, mag.item(), relerr.item())
-            print(A.flatten()[-6:])
-            print(B.flatten()[-6:])
-            out = A.flatten()[-6:]*B.flatten()[-6:]
-            print(out)
-            print(out[:-1].sum())
-            print('='*80)
-            print(C1.flatten()[-6:])
-            print(C2.flatten()[-6:])
-            #assert False, 'ERROR'
+            errs.append(err)
+            relerrs.append(relerr)
 
-        c = int(C1.numel()*0.001)
-        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
+            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            #    print('')
+            #    print(i, err, mag.item(), relerr.item())
+            #    print(A.flatten()[-6:])
+            #    print(B.flatten()[-6:])
+            #    out = A.flatten()[-6:]*B.flatten()[-6:]
+            #    print(out)
+            #    print(out[:-1].sum())
+            #    print('='*80)
+            #    print(C1.flatten()[-6:])
+            #    print(C2.flatten()[-6:])
+            #    #assert False, 'ERROR'
+
+        c = int(C1.numel()*0.00125*(dim/256))+1
+        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))
 
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
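
Note on the scaled assertion (illustrative sketch, not part of the commit): the new check allows a number of mismatched output elements that scales with the inner dimension, c = int(C1.numel()*0.00125*(dim/256))+1, i.e. about 0.125% of the output elements at dim=256, growing linearly with dim and never dropping below one. The helper assert_all_approx_close() is defined elsewhere in tests/test_functional.py and is not shown in this patch; the snippet below is only an assumed approximation of its semantics, written to illustrate what the scaled count is gating.

    import torch

    def assert_all_approx_close_sketch(a, b, rtol, atol, count):
        # Count elements outside the rtol/atol tolerance and allow up to `count`
        # of them, since tensor-core GEMM results are not bit-exact run to run.
        mismatches = (~torch.isclose(a, b, rtol=rtol, atol=atol)).sum().item()
        assert mismatches <= count, f"{mismatches} mismatched elements (allowed: {count})"

    # Usage mirroring the test above (names taken from the test, values illustrative):
    # c = int(C1.numel()*0.00125*(dim/256))+1
    # assert_all_approx_close_sketch(C1, C2, rtol=1e-5, atol=0.01, count=c)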