From 7bfa09d0fcaa524863bcc8ea71436f99423bbd3f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 1 May 2023 16:38:09 -0700 Subject: [PATCH] 8x32 240 6 warps. --- csrc/kernels.cu | 50 ++++++++++++++++++++++++++----------------------- csrc/ops.cu | 6 ++++-- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4e3a4a3..b03c6ca 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3041,7 +3041,7 @@ template __device__ inline void vector_l } } -#define WARPS 4 +#define WARPS 6 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3052,26 +3052,26 @@ template __global__ void gemm_device(int M, //typedef cub::BlockReduce BlockReduce; //// Allocate shared memory for BlockReduce //__shared__ typename BlockReduce::TempStorage reduce; - int col_offset = blockIdx.x *16; + int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; T local_A[1]; - T local_B[16]; + T local_B[32]; - const int a_tile_offset = (16*16 + 16); - const int b_tile_offset = (16*16 + 16); - const int c_tile_offset = 16*16 + 24; + const int a_tile_offset = (8*16 + 16); + const int b_tile_offset = (16*32 + 16); + const int c_tile_offset = 8*32 + 24; - __shared__ T smem_A[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_C[16*16]; + __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[8*32]; - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); @@ -3082,7 +3082,7 @@ template __global__ void gemm_device(int M, //for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x) // smem_B[i] = T(0); - for(int i = threadIdx.x; i < 16*16; i+=blockDim.x) + for(int i = threadIdx.x; i < 8*32; i+=blockDim.x) smem_C[i] = T(0); __syncthreads(); @@ -3099,14 +3099,14 @@ template __global__ void gemm_device(int M, { local_A[0] = A[idx]; - #pragma unroll 16 - for(int col = 0; col < 16; col++) + #pragma unroll 32 + for(int col = 0; col < 32; col++) local_B[col] = B[(col_offset+col)*ldb+idx]; smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0]; - #pragma unroll 16 - for(int col = 0; col < 16; col++) + #pragma unroll 32 + for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col]; } ticktock = ticktock == 0 ? 1 : 0; @@ -3120,14 +3120,14 @@ template __global__ void gemm_device(int M, { local_A[0] = A[idx]; - #pragma unroll 16 - for(int col = 0; col < 16; col++) + #pragma unroll 32 + for(int col = 0; col < 32; col++) local_B[col] = B[(col_offset+col)*ldb+idx]; smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - #pragma unroll 16 - for(int col = 0; col < 16; col++) + #pragma unroll 32 + for(int col = 0; col < 32; col++) smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; } ticktock = ticktock == 0 ? 1 : 0; @@ -3143,7 +3143,7 @@ template __global__ void gemm_device(int M, // 129 mu if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 16, wmma::mem_row_major); + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); __syncthreads(); //if(threadIdx.x >= 16){ return; } @@ -3185,7 +3185,7 @@ template __global__ void gemm_device(int M, //if(threadIdx.x < 8 && col_offset + threadIdx.x < M) //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; - if(threadIdx.x < 16 && col_offset + threadIdx.x < M) + if(threadIdx.x < 32 && col_offset + threadIdx.x < M) out[col_offset + threadIdx.x] = smem_C[threadIdx.x]; } @@ -3470,18 +3470,22 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/ops.cu b/csrc/ops.cu index d0e903f..2ccb418 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -678,7 +678,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size) template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { - int num_blocks = (m+15)/16; + int num_blocks = (m+31)/32; cout << num_blocks << endl; cout << lda << endl; @@ -693,7 +693,9 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 192, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); }