Halved shared memory 466.
This commit is contained in:
parent
30d03e0254
commit
cabcd9b9d5
|
@ -3053,25 +3053,23 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
//// Allocate shared memory for BlockReduce
|
||||
//__shared__ typename BlockReduce::TempStorage reduce;
|
||||
int col_offset = blockIdx.x *8;
|
||||
const int warp_id = threadIdx.x / 32;
|
||||
const int half_warp_id = threadIdx.x / 16;
|
||||
const int half_warp_lane = threadIdx.x % 16;
|
||||
|
||||
T local_A[64/BITS];
|
||||
T local_B[64/BITS];
|
||||
T local_C[8];
|
||||
T local_A[1];
|
||||
T local_B[8];
|
||||
|
||||
const int a_tile_offset = 32*16 + 16;
|
||||
const int b_tile_offset = 16*8 + 16;
|
||||
const int c_tile_offset = 32*8 + 24;
|
||||
|
||||
__shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))];
|
||||
__shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))];
|
||||
__shared__ T smem_A[WARPS*32*16 + (16*(WARPS-1))];
|
||||
__shared__ T smem_B[WARPS*16*8 + (16*(WARPS-1))];
|
||||
__shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))];
|
||||
|
||||
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
|
||||
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
|
||||
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a2_frag;
|
||||
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b2_frag;
|
||||
wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
|
||||
|
||||
wmma::fill_fragment(c_frag, 0.0f);
|
||||
|
@ -3087,9 +3085,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
smem_C[i] = T(0);
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll 8
|
||||
for(int k = 0; k < 8; k++)
|
||||
local_C[k] = T(0);
|
||||
//#pragma unroll 8
|
||||
//for(int k = 0; k < 8; k++)
|
||||
//local_C[k] = T(0);
|
||||
|
||||
//int block_idx = 0;
|
||||
//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
|
||||
|
@ -3097,27 +3095,45 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
{
|
||||
int idx = base_idx + threadIdx.x;
|
||||
|
||||
if(idx >= K)
|
||||
for(int k = 0; k < 2; k++)
|
||||
{
|
||||
smem_A[threadIdx.x] = 0.0f;
|
||||
//smem_B[threadIdx.x] = 0.0f;
|
||||
if(k == 0)
|
||||
{
|
||||
if(idx < K)
|
||||
{
|
||||
local_A[0] = A[idx];
|
||||
|
||||
#pragma unroll 8
|
||||
for(int col = 0; col < 8; col++)
|
||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if(idx >= K)
|
||||
{
|
||||
smem_A[threadIdx.x] = 0.0f;
|
||||
//smem_B[threadIdx.x] = 0.0f;
|
||||
}
|
||||
else
|
||||
{
|
||||
if((k == 0 && half_warp_id % 2 == 0) ||
|
||||
(k == 1 && half_warp_id % 2 == 1))
|
||||
{
|
||||
smem_A[half_warp_lane + (warp_id*a_tile_offset)] = local_A[0];
|
||||
|
||||
#pragma unroll 8
|
||||
for(int col = 0; col < 8; col++)
|
||||
smem_B[half_warp_lane + (warp_id*b_tile_offset) + (col*16)] = local_B[col];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*a_tile_offset]), 16); // 111 mu
|
||||
wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*b_tile_offset]), 16); // 35 mu
|
||||
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
|
||||
}
|
||||
else
|
||||
{
|
||||
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx];
|
||||
|
||||
for(int col = 0; col < 8; col++)
|
||||
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
|
||||
wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
|
||||
wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu
|
||||
wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu
|
||||
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
|
||||
wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag);
|
||||
}
|
||||
|
||||
// 129 mu
|
||||
|
|
Loading…
Reference in New Issue
Block a user