Fixed bank conflicts in non-vector load 422.
This commit is contained in:
parent
c35ed09b66
commit
e01d4e033d
|
@@ -3060,11 +3060,11 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
T local_B[64/BITS];
|
||||
T local_C[8];
|
||||
|
||||
- const int a_tile_offset = 32*16;
|
||||
- const int b_tile_offset = 16*8;
|
||||
+ const int a_tile_offset = 32*16 + 16;
|
||||
+ const int b_tile_offset = 16*8 + 16;
|
||||
|
||||
- __shared__ T smem_A[WARPS*32*16*2];
|
||||
- __shared__ T smem_B[WARPS*16*8*2];
|
||||
+ __shared__ T smem_A[WARPS*32*16*2 + (16*1)];
|
||||
+ __shared__ T smem_B[WARPS*16*8*2 + (16*1)];
|
||||
__shared__ T smem_C[WARPS*32*8];
|
||||
|
||||
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
|
||||
|
@@ -3114,8 +3114,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
|
||||
wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
|
||||
wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
|
||||
- wmma::load_matrix_sync(a2_frag, &(smem_A[32*16]), 16); // 111 mu
|
||||
- wmma::load_matrix_sync(b2_frag, &(smem_B[16*8]), 16); // 35 mu
|
||||
+ wmma::load_matrix_sync(a2_frag, &(smem_A[a_tile_offset]), 16); // 111 mu
|
||||
+ wmma::load_matrix_sync(b2_frag, &(smem_B[b_tile_offset]), 16); // 35 mu
|
||||
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
|
||||
wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user