From 30d03e0254f9868f29392f318787667d5bdff891 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 30 Apr 2023 18:55:12 -0700 Subject: [PATCH] 64 threads, high smem, 434. --- csrc/kernels.cu | 48 ++++++++++++++++++++++++------------------------ csrc/ops.cu | 3 ++- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index dffd40c..4002117 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3041,7 +3041,7 @@ template __device__ inline void vector_l } } -#define WARPS 1 +#define WARPS 2 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3062,10 +3062,11 @@ template __global__ void gemm_device(int M, const int a_tile_offset = 32*16 + 16; const int b_tile_offset = 16*8 + 16; + const int c_tile_offset = 32*8 + 24; - __shared__ T smem_A[WARPS*32*16*2 + (16*1)]; - __shared__ T smem_B[WARPS*16*8*2 + (16*1)]; - __shared__ T smem_C[WARPS*32*8]; + __shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))]; + __shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))]; + __shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))]; wmma::fragment a_frag; wmma::fragment b_frag; @@ -3092,46 +3093,45 @@ template __global__ void gemm_device(int M, //int block_idx = 0; //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) - for(int base_idx = 0; base_idx < K; base_idx+=32) + for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) { int idx = base_idx + threadIdx.x; - if(idx >= K) - { - smem_A[threadIdx.x] = 0.0f; - //smem_B[threadIdx.x] = 0.0f; - } - else - { + if(idx >= K) + { + smem_A[threadIdx.x] = 0.0f; + //smem_B[threadIdx.x] = 0.0f; + } + else + { + smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx]; - smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx]; - - for(int col = 0; col < 8; col++) - smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx]; - } + for(int col = 0; col < 8; col++) + smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx]; + } __syncthreads(); wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu - wmma::load_matrix_sync(a2_frag, &(smem_A[a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b2_frag, &(smem_B[b_tile_offset]), 16); // 35 mu + wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); } // 129 mu - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major); + wmma::store_matrix_sync(&(smem_C[half_warp_id*c_tile_offset]), c_frag, 8, wmma::mem_row_major); __syncthreads(); //if(threadIdx.x >= 16){ return; } //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); //if(threadIdx.x < 32) - //if(warp_lane < 8 && warp_id > 0) - // //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; - // atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]); - //__syncthreads(); + if(half_warp_lane < 8 && half_warp_id > 0) + //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; + atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]); + __syncthreads(); //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); //if(threadIdx.x == 0) diff --git a/csrc/ops.cu b/csrc/ops.cu index 5c4f9c0..57d5cca 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -693,7 +693,8 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)