From 77f15fdce9f11324f6616e4fccc03d16f61347e6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 2 May 2023 11:38:11 -0700 Subject: [PATCH] Shared memory efficient 240. --- csrc/kernels.cu | 80 ++++++++++------------------------------ csrc/ops.cu | 2 +- tests/test_functional.py | 4 +- 3 files changed, 22 insertions(+), 64 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a528d16..8b5544a 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3041,7 +3041,7 @@ template __device__ inline void vector_l } } -#define WARPS 6 +#define WARPS 5 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3061,23 +3061,18 @@ template __global__ void gemm_device(int M, T local_A[1]; T local_B[32]; - const int a_tile_offset = (16 + 16); + const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); - __shared__ T smem_A[8*16 + (4*16*(batch_size_warps-1))]; + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_C[8*32]; + //__shared__ T smem_C[8*32]; wmma::fragment a_frag; wmma::fragment b_frag; wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); - for(int i = threadIdx.x; i < 8*32; i+=blockDim.x) - smem_C[i] = T(0); - __syncthreads(); - int ticktock = 0; int idx = 0 + threadIdx.x; // prefetch @@ -3155,63 +3150,24 @@ template __global__ void gemm_device(int M, } __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + ticktock = ticktock == 0 ? 1 : 0; - if(warp_id == (WARPS-1)) - for(int k = 0; k < batch_size_warps; k++) - { - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - __syncthreads(); + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } // 129 mu if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); - __syncthreads(); + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); - - //if(threadIdx.x >= 16){ return; } - //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); - - //if(threadIdx.x < 32) - //if(half_warp_lane < 8 && half_warp_id > 0) - // //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; - // atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]); - //__syncthreads(); - - //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); - //if(threadIdx.x == 0) - // for(int row = 0; row < 32; row++) - // { - // printf("row %i ", row); - // for(int id = 0; id < 4; id++) - // { - // printf(" id %i: ", id); - // for(int k = 0; k < 8; k++) - // printf("%f ", (float)smem_C[k + (row*8) + (id*32*8)]); - // printf("\n"); - // } - // } - - //__syncthreads(); - - //if((float)local_C[0] !=0.0f) - // printf("%i %i %f\n", warp_lane, warp_id, (float)local_C[0]); - //local_C[0] = WarpReduce(temp_storage).Sum(local_C[0]); - - //__syncwarp(); - - ////for(int i = threadIdx.x; i < 32*8; i+=blockDim.x) - ////{ - // if((float)local_C[0] !=0.0f) - // printf("%i %f\n", 0, (float)local_C[0]); - //} - - //if(threadIdx.x < 8 && col_offset + threadIdx.x < M) - //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; - if(threadIdx.x < 32 && col_offset + threadIdx.x < M) - out[col_offset + threadIdx.x] = smem_C[threadIdx.x]; + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; } template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) @@ -3496,6 +3452,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out, //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); @@ -3506,6 +3463,7 @@ template __global__ void gemm_device(int M, int N, int K, half * _ //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); diff --git a/csrc/ops.cu b/csrc/ops.cu index 6bf1e89..16d82f9 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -693,7 +693,7 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 192, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); diff --git a/tests/test_functional.py b/tests/test_functional.py index 4c86d83..62dd1cb 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2358,9 +2358,9 @@ def test_normal_map_tree(): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_cutlass3_gemm(dtype): - #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: - for dim in [4096]: + #for dim in [4096]: errs = [] relerrs = [] max_err = 0