From e01d4e033df8f94b28ae4e38608c621653673338 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 30 Apr 2023 18:28:52 -0700 Subject: [PATCH] Fixed bank conflicts in non-vector load 422. --- csrc/kernels.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 5d1982d..dffd40c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3060,11 +3060,11 @@ template __global__ void gemm_device(int M, T local_B[64/BITS]; T local_C[8]; - const int a_tile_offset = 32*16; - const int b_tile_offset = 16*8; + const int a_tile_offset = 32*16 + 16; + const int b_tile_offset = 16*8 + 16; - __shared__ T smem_A[WARPS*32*16*2]; - __shared__ T smem_B[WARPS*16*8*2]; + __shared__ T smem_A[WARPS*32*16*2 + (16*1)]; + __shared__ T smem_B[WARPS*16*8*2 + (16*1)]; __shared__ T smem_C[WARPS*32*8]; wmma::fragment a_frag; @@ -3114,8 +3114,8 @@ template __global__ void gemm_device(int M, wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu - wmma::load_matrix_sync(a2_frag, &(smem_A[32*16]), 16); // 111 mu - wmma::load_matrix_sync(b2_frag, &(smem_B[16*8]), 16); // 35 mu + wmma::load_matrix_sync(a2_frag, &(smem_A[a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b2_frag, &(smem_B[b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); }