diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 8b5544a..65ed19e 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3058,8 +3058,8 @@ template __global__ void gemm_device(int M, const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; - T local_A[1]; - T local_B[32]; + T local_A[2]; + T local_B[64]; const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); @@ -3075,14 +3075,32 @@ template __global__ void gemm_device(int M, int ticktock = 0; int idx = 0 + threadIdx.x; + int loaded_values = 0; // prefetch if(idx < K && warp_id < (WARPS-1)) { - local_A[0] = A[idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = B[(col_offset+col)*ldb+idx]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32]; + } + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+32]; + } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; @@ -3113,11 +3131,35 @@ template __global__ void gemm_device(int M, __syncthreads(); if(idx < K && warp_id < (WARPS-1)) { - local_A[0] = A[idx]; + //local_A[0] = A[idx]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = B[(col_offset+col)*ldb+idx]; + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32]; + } + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+32]; + + + } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; diff --git a/tests/test_functional.py b/tests/test_functional.py index 62dd1cb..e9a67f5 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype): #print('') #print(A) #print(B.t()) - #A[:, :-3] = 0 - #B[:, :-3] = 0 + #A[:, :-1] = 0 + #B[:, :-1] = 0 C1 = torch.matmul(A, B.t()) @@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype): #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: # print('') - # print(i, err, mag.item(), relerr.item()) + # print(i, err, relerr) # print(A.flatten()[-6:]) # print(B.flatten()[-6:]) # out = A.flatten()[-6:]*B.flatten()[-6:] @@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype): c = int(C1.numel()*0.0014*(dim/256))+1 - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True) #print(c/math.sqrt(dim)) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim))