diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4002117..301221c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3053,25 +3053,23 @@ template __global__ void gemm_device(int M, //// Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; int col_offset = blockIdx.x *8; + const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; - T local_A[64/BITS]; - T local_B[64/BITS]; - T local_C[8]; + T local_A[1]; + T local_B[8]; const int a_tile_offset = 32*16 + 16; const int b_tile_offset = 16*8 + 16; const int c_tile_offset = 32*8 + 24; - __shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))]; - __shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))]; + __shared__ T smem_A[WARPS*32*16 + (16*(WARPS-1))]; + __shared__ T smem_B[WARPS*16*8 + (16*(WARPS-1))]; __shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))]; wmma::fragment a_frag; wmma::fragment b_frag; - wmma::fragment a2_frag; - wmma::fragment b2_frag; wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); @@ -3087,9 +3085,9 @@ template __global__ void gemm_device(int M, smem_C[i] = T(0); __syncthreads(); - #pragma unroll 8 - for(int k = 0; k < 8; k++) - local_C[k] = T(0); + //#pragma unroll 8 + //for(int k = 0; k < 8; k++) + //local_C[k] = T(0); //int block_idx = 0; //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) @@ -3097,27 +3095,45 @@ template __global__ void gemm_device(int M, { int idx = base_idx + threadIdx.x; - if(idx >= K) + for(int k = 0; k < 2; k++) { - smem_A[threadIdx.x] = 0.0f; - //smem_B[threadIdx.x] = 0.0f; + if(k == 0) + { + if(idx < K) + { + local_A[0] = A[idx]; + + #pragma unroll 8 + for(int col = 0; col < 8; col++) + local_B[col] = B[(col_offset+col)*ldb+idx]; + } + + } + + if(idx >= K) + { + smem_A[threadIdx.x] = 0.0f; + //smem_B[threadIdx.x] = 0.0f; + } + else + { + if((k == 0 && half_warp_id % 2 == 0) || + (k == 1 && half_warp_id % 2 == 1)) + { + smem_A[half_warp_lane + (warp_id*a_tile_offset)] = local_A[0]; + + #pragma unroll 8 + for(int col = 0; col < 8; col++) + smem_B[half_warp_lane + (warp_id*b_tile_offset) + (col*16)] = local_B[col]; + } + } + + __syncthreads(); + + wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); } - else - { - smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx]; - - for(int col = 0; col < 8; col++) - smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx]; - } - - __syncthreads(); - - wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu - wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); } // 129 mu