Fixed bank conflicts in non-vector load 422.

This commit is contained in:
Tim Dettmers 2023-04-30 18:28:52 -07:00
parent c35ed09b66
commit e01d4e033d

View File

@@ -3060,11 +3060,11 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
T local_B[64/BITS]; T local_B[64/BITS];
T local_C[8]; T local_C[8];
const int a_tile_offset = 32*16; const int a_tile_offset = 32*16 + 16;
const int b_tile_offset = 16*8; const int b_tile_offset = 16*8 + 16;
__shared__ T smem_A[WARPS*32*16*2]; __shared__ T smem_A[WARPS*32*16*2 + (16*1)];
__shared__ T smem_B[WARPS*16*8*2]; __shared__ T smem_B[WARPS*16*8*2 + (16*1)];
__shared__ T smem_C[WARPS*32*8]; __shared__ T smem_C[WARPS*32*8];
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag; wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
@@ -3114,8 +3114,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
wmma::load_matrix_sync(a2_frag, &(smem_A[32*16]), 16); // 111 mu wmma::load_matrix_sync(a2_frag, &(smem_A[a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b2_frag, &(smem_B[16*8]), 16); // 35 mu wmma::load_matrix_sync(b2_frag, &(smem_B[b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag);
} }