diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 53a183d..24b004b 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include <mma.h>
 #include
 #include
 
@@ -23,6 +24,8 @@
 #define NUM 4
 #define NUM_BLOCK 4096
 
+using namespace nvcuda;
+
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 __device__ float atomicMax(float* address, float val) {
   int* address_as_i = reinterpret_cast<int*>(address);
@@ -3041,62 +3044,164 @@ template __device__ inline void vector_l
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
-  typedef cub::BlockReduce<T, THREADS> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage reduce;
-  int col_offset = blockIdx.x *8;
+  typedef cub::WarpReduce<T> WarpReduce;
+  // Allocate WarpReduce shared memory for one warp
+  //__shared__ typename WarpReduce::TempStorage temp_storage;
 
-  T local_A[128/BITS];
-  T local_B[128/BITS];
+  //typedef cub::BlockReduce<T, THREADS> BlockReduce;
+  //// Allocate shared memory for BlockReduce
+  //__shared__ typename BlockReduce::TempStorage reduce;
+
+  int col_offset = blockIdx.x *8;
+  const int warp_id = threadIdx.x / 32;
+  const int warp_lane = threadIdx.x % 32;
+
+  T local_A[64/BITS];
+  T local_B[64/BITS];
   T local_C[8];
 
-  __shared__ T smem_C[8];
+  __shared__ T smem_A[4*32*16];
+  __shared__ T smem_B[4*16*8];
+  __shared__ T smem_C[4*32*8];
 
-  if(threadIdx.x < 8)
-    smem_C[threadIdx.x] = T(0);
+  wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
+  wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
+  wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
+
+  wmma::fill_fragment(c_frag, 0.0f);
+
+
+  for(int i = threadIdx.x; i < 32*16*4; i+=blockDim.x)
+    smem_A[i] = T(0);
+
+  for(int i = threadIdx.x; i < 32*8*4; i+=blockDim.x)
+    smem_B[i] = T(0);
+
+  for(int i = threadIdx.x; i < 32*8*THREADS/32; i+=blockDim.x)
+    smem_C[i] = T(0);
 
   __syncthreads();
 
   #pragma unroll 8
   for(int k = 0; k < 8; k++)
     local_C[k] = T(0);
-
-  for(int idx = threadIdx.x*128/BITS; idx < K; idx+=blockDim.x*128/BITS)
+  int block_idx = 0;
+  //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
+  for(int base_idx = 0; base_idx < K; base_idx+=64)
   {
-    vector_load<T, 128/BITS>(local_A, A, idx, idx, K);
-    for(int col = 0; col < 8; col++)
+    int tidx = threadIdx.x*4;
+
+    if(base_idx % (4*blockDim.x) == 0)
     {
-      int offset_B = (col_offset+col)*ldb;
-      vector_load<T, 128/BITS>(local_B, B, offset_B+idx, idx, K);
-
-      #pragma unroll 128/BITS
-      for(int k = 0; k < 128/BITS; k++)
-        local_C[col] += local_A[k]*local_B[k];
+      vector_load<T, 4>(local_A, A, base_idx+tidx, base_idx+tidx, K); // 54 mu
+      block_idx = 0;
     }
-  }
 
-  #pragma unroll 8
-  for(int k = 0; k < 8; k++)
-  {
-    local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
+    for(int k = 0; k < 4; k++)
+    {
+      if((threadIdx.x >= block_idx*16) && (threadIdx.x < (block_idx+1)*16))
+        smem_A[(threadIdx.x % 16) + (32*16*k)] = local_A[k]; // 54 mu
+    }
+    block_idx += 1;
+
+    // 4 warps; one warp loads 4*32=64 values in total -> 4 columns at a time
+    // we need 8 columns, so 2 loads and smem stores
+    // we need a half-warp to load one column at a time
+    for(int j = 0; j < 2; j++)
+    {
+      int col = warp_id + (j*4);
+      int offset_B = (col_offset+col)*ldb;
+      vector_load<T, 4>(local_B, B, offset_B+base_idx+warp_lane*4, base_idx+warp_lane*4, K); // 171 mu
+
+
+      //#pragma unroll 4
+      //for(int k = 0; k < 4; k++)
+      //  if((float)local_B[k] != 0.0)
+      //    printf("%i %i %i %i %f\n", j, warp_id, warp_lane, k, (float)local_B[k]);
+
+      // the load and store patterns are different:
+      // we want to load 64 consecutive values with one warp,
+      // but we need to store those across 4 fragments since
+      // the max column width is 16.
+
+      // every 16 values start a new tile for each warp
+      //int tile_idx = warp_lane/16;
+      #pragma unroll 4
+      for(int k = 0; k < 4; k++)
+        smem_B[(warp_lane % 16) + (col*16) + (k*16*8)] = local_B[k]; // 171 mu
+    }
+
+
+    __syncthreads();
+
+    //if(threadIdx.x == 0)
+    //  for(int w = 0; w < 4; w++)
+    //    for(int trow = 0; trow < 32; trow++)
+    //      for(int tcol = 0; tcol < 16; tcol++)
+    //        if((float)smem_A[trow + tcol*32 + (w*32*16)] != 0.0)
+    //          printf("A %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
+
+    //if(threadIdx.x == 0)
+    //  for(int w = 0; w < 4; w++)
+    //    for(int trow = 0; trow < 16; trow++)
+    //      for(int tcol = 0; tcol < 8; tcol++)
+    //        if((float)smem_B[trow + tcol*16 + (w*16*8)] != 0.0)
+    //          printf("B %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
+
+
+    //__syncthreads();
+
+    wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*32*16]), 16); // 111 mu
+    wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*16*8]), 16); // 35 mu
+    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
   }
 
-  if(threadIdx.x == 0)
-  {
-    #pragma unroll 8
-    for(int k = 0; k < 8; k++)
-      smem_C[k] = local_C[k];
-  }
-  else if(threadIdx.x >= 32)
-    // early return for unused warps
-    return;
+  // 129 mu
+  wmma::store_matrix_sync(&(smem_C[warp_id*32*8]), c_frag, 8, wmma::mem_row_major);
+  __syncthreads();
 
-  __syncwarp();
+  //if(threadIdx.x >= 16){ return; }
+  //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
+  //if(threadIdx.x < 32)
+  if(warp_lane < 8 && warp_id > 0)
+    //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
+    atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
+  __syncthreads();
+
+  //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
+  //if(threadIdx.x == 0)
+  //  for(int row = 0; row < 32; row++)
+  //  {
+  //    printf("row %i ", row);
+  //    for(int id = 0; id < 4; id++)
+  //    {
+  //      printf(" id %i: ", id);
+  //      for(int k = 0; k < 8; k++)
+  //        printf("%f ", (float)smem_C[k + (row*8) + (id*32*8)]);
+  //      printf("\n");
+  //    }
+  //  }
+
+  //__syncthreads();
+
+  //if((float)local_C[0] !=0.0f)
+  //  printf("%i %i %f\n", warp_lane, warp_id, (float)local_C[0]);
+  //local_C[0] = WarpReduce(temp_storage).Sum(local_C[0]);
+
+  //__syncwarp();
+
+  ////for(int i = threadIdx.x; i < 32*8; i+=blockDim.x)
+  ////{
+  //  if((float)local_C[0] !=0.0f)
+  //    printf("%i %f\n", 0, (float)local_C[0]);
+  //}
+
+  //if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
+  //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
 
   if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
-    out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
+    out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
 }
 
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
@@ -3378,12 +3483,16 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 //                                    half alpha, half beta);
 
 // these are not used and make no sense, but the compiler needs them
-template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 
 // these are not used and make no sense, but the compiler needs them
-template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 
 template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
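Note on the tensor-core path added above: the kernel walks K in chunks of 64, and each of the block's 4 warps multiplies a row-major 32x16 slice of smem_A by a col-major 16x8 slice of smem_B with the m32n8k16 wmma shape, accumulating into a per-warp 32x8 fragment; the per-warp tiles are then summed through shared memory with atomicAdd. The standalone sketch below (hypothetical name tile_mma, a single warp and a single K-slice; requires sm_70+ under nvcc) shows just that fragment pipeline, not the full kernel:

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// One warp multiplies a row-major 32x16 A tile by a col-major 16x8 B tile
// and writes the 32x8 product row-major; launch as tile_mma<<<1, 32>>>(A, B, C).
__global__ void tile_mma(const half *A, const half *B, half *C)
{
    wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);      // zero the accumulator, as in the kernel
    wmma::load_matrix_sync(a_frag, A, 16);  // leading dimension 16: 32x16 tile
    wmma::load_matrix_sync(b_frag, B, 16);  // leading dimension 16: 16x8 tile, col-major
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, 8, wmma::mem_row_major);
}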
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 07e7107..d83fc6e 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -678,7 +678,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
 
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
 {
-  dim3 dimBlock(128);
   int num_blocks = (m+7)/8;
 
   cout << num_blocks << endl;
@@ -689,16 +688,17 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
   cout << m << endl;
   cout << n << endl;
   cout << k << endl;
-  if(bits == 32)
-    gemm_device<T, 32, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
-  else if(bits == 16)
-    gemm_device<T, 16, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+  //if(bits == 32)
+    //gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+    //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+  if(bits == 16)
+    gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+    //gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
 }
 
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
-  dim3 dimBlock(128);
   int num_blocks = (m+7)/8;
 
   cout << num_blocks << endl;
@@ -709,7 +709,8 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   cout << m << endl;
   cout << n << endl;
   cout << k << endl;
-  kgemm_4bit_inference<T, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
 
 //==============================================================
@@ -717,7 +718,7 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
 //==============================================================
 
 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
+//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
 template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
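The gemm_host launch above sizes the grid so that each 128-thread block (4 warps) produces the 8 outputs selected by col_offset = blockIdx.x * 8, hence num_blocks = (m+7)/8, a ceiling division. A minimal sketch of the same arithmetic (hypothetical launch_gemm_fp16 helper, assuming the gemm_device<half, 16, 128> instantiation shown above):

// ceil(m / 8) blocks of 128 threads, mirroring gemm_host's sizing
void launch_gemm_fp16(int m, int n, int k, half *A, half *B, half *out,
                      int lda, int ldb, int ldc)
{
    int num_blocks = (m + 7) / 8;   // e.g. m = 16384 -> 2048 blocks, m = 9 -> 2 blocks
    gemm_device<half, 16, 128><<< num_blocks, 128 >>>(m, n, k, A, B, out, lda, ldb, ldc);
}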
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index bdf821c..26f16f2 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -20,8 +20,8 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
 void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
 
-void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
-{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
+//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
+//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
 
 void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
 { gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
@@ -316,8 +316,8 @@ extern "C"
   void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
   void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }
 
-  void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
-  { gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
+  //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
+  //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
 
   void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
   { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
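pythonInterface.c follows the usual two-level pattern for exposing CUDA templates to Python: a template cannot cross a C ABI, so a typed C++ wrapper (gemm_host_fp16) pins the instantiation, and an extern "C" wrapper (cgemm_host_fp16) gives it an unmangled symbol that ctypes can look up. A self-contained sketch of the pattern with a hypothetical scale() template (not from the diff):

// template -> typed C++ wrapper -> extern "C" symbol with a stable name
template <typename T> void scale(T *x, T v, int n)
{
    for (int i = 0; i < n; i++)
        x[i] *= v;               // element-wise scaling, stand-in for the real work
}

void scale_fp32(float *x, float v, int n) { scale<float>(x, v, n); }  // pins T=float

extern "C" void cscale_fp32(float *x, float v, int n) { scale_fp32(x, v, n); }  // unmangled for ctypes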
diff --git a/tests/test_functional.py b/tests/test_functional.py
index f58cd43..e2ecdcb 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2358,6 +2358,8 @@ def test_cutlass3_gemm(dtype):
     for i in range(1):
         #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
         #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+        #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+        #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
         A = torch.rand(1, 4096, dtype=dtype, device='cuda')
         B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
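The test exercises the fp16 kernel with A of shape (1, 4096) and B of shape (4*4096, 4096), so each of the 16384 outputs is the dot product of A with one row of B. A host-side reference in the same spirit (hypothetical reference_gemv helper, compiled with nvcc so the half conversions are available on the host; fp32 accumulation keeps fp16 rounding out of the comparison):

#include <cuda_fp16.h>

// out[i] = dot(A, B[i, :]) for M outputs over K-element rows
void reference_gemv(const half *A, const half *B, float *out, int M, int K)
{
    for (int i = 0; i < M; i++)
    {
        float acc = 0.0f;
        for (int k = 0; k < K; k++)
            acc += __half2float(A[k]) * __half2float(B[(long)i * K + k]);
        out[i] = acc;
    }
}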