diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 67f9a3c..3310285 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2947,117 +2947,212 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
 //// 9. write outputs to matmul output matrix
 //}
-#define ROWS 2
-template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
-// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
-// 1. Load dataB into register
-// 2. Dequantize B
-// 3. Fetch data from A and multiply
-  typedef cub::BlockLoad LoadA;
-  //__shared__ typename LoadA::TempStorage loada;
-  typedef cub::BlockLoad LoadB;
-  //__shared__ typename LoadB::TempStorage loadb;
   typedef cub::BlockReduce BlockReduce;
-  // Allocate shared memory for BlockReduce
-  //__shared__ typename BlockReduce::TempStorage reduce;
+  __shared__ typename BlockReduce::TempStorage reduce;
+  int col_offset = blockIdx.x *8;
-  __shared__ union {
-    typename BlockReduce::TempStorage reduce;
-    typename LoadB::TempStorage loadb;
-    typename LoadA::TempStorage loada;
-  } temp_storage;
+  T local_A[8];
+  T local_B[8];
+  T local_C[8];
+  __shared__ T smem_C[8];
-  T dataA[ITEMS];
-  T local_B[ITEMS];
-  T local_accC[ROWS];
-  int valid_items = 0;
-  const int col_offset = blockIdx.x * 8;
-
-  __shared__ T tileA[ROWS*THREADS*ITEMS];
-  __shared__ T accumulatorC[ROWS*8];
-
-  //#pragma unroll 8
-  //for(int i = 0; i < 8; i++)
-  //  tileA[threadIdx.x + (i*256)] = 0.0f;
-  //__syncthreads();
-  if(threadIdx.x < 64)
-    accumulatorC[threadIdx.x] = 0.0f;
+  if(threadIdx.x < 8)
+    smem_C[threadIdx.x] = T(0);
   __syncthreads();
+  #pragma unroll 8
+  for(int k = 0; k < 8; k++)
+    local_C[k] = T(0);
-  for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS)
-  {
-    valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
-    int baserow = 0;
-    for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
-    {
-      LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
-      #pragma unroll ITEMS
-      for(int k = 0; k < ITEMS; k++)
-        tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k];
+  for(int idx = threadIdx.x*8; idx < K; idx+=blockDim.x*8)
+  {
-      __syncthreads();
-    }
-    baserow += ROWS;
-
-    // load 16 columns from B at a time. B is transposed, so its like loading rows
-    // each warp loads one row
-    // each thread loads 128 byte
-
-    // col: inner_idx + warp_lane
-    // row: ldb*(offset + warp_id)
-    for(int col = 0; col < 8 && (col_offset + col) < M; col++)
+    if(idx + 8 <= K)
+      reinterpret_cast(local_A)[0] = reinterpret_cast(A)[idx/8];
+    else
     {
-      int colB = col_offset + col;
-
-      for(int k = 0; k < ROWS; k++)
-        local_accC[k] = 0.0f;
-
-      int base_idxB = ldb*colB;
-      valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
-      LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
-      __syncthreads();
-
-      for(int row = 0; row < ROWS && row < N; row++)
+      for(int k = 0; k < 8; k++)
       {
-        #pragma unroll ITEMS
-        for(int k = 0; k < ITEMS; k++)
-        {
-          int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k);
-          local_accC[row] += tileA[idxA]*local_B[k];
-        }
-
-        local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
-        if(threadIdx.x == 0)
-          atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
+        if(idx + k < K)
+          local_A[k] = A[idx+k];
+        else
+          local_A[k] = 0.0f;
       }
     }
-  }
-
-  for(int row = 0; row < ROWS && row < N; row++)
-  {
-    int out_idx = ldc*row + col_offset;
-
-    //if(threadIdx.x < 8)
-    //  if(accumulatorC[row*8 + threadIdx.x] != 0.0)
-    //    printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
-
-    if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
+    for(int col = 0; col < 8; col++)
     {
-      //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
-      out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
+      int offset_B = (col_offset+col)*ldb;
+      if(idx + 8 <= K)
+        reinterpret_cast(local_B)[0] = reinterpret_cast(B)[(offset_B+idx)/8];
+      else
+      {
+        for(int k = 0; k < 8; k++)
+        {
+          if(idx + k < K)
+            local_B[k] = B[(offset_B+idx)+k];
+          else
+            local_B[k] = 0.0f;
+        }
+      }
+
+      #pragma unroll 8
+      for(int k = 0; k < 8; k++)
+      {
+        local_C[col] += local_A[k]*local_B[k];
+        //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
+        //  printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]);
+      }
+    }
   }
+  #pragma unroll 8
+  for(int k = 0; k < 8; k++)
+  {
+    local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
+    __syncthreads();
+  }
+
+  if(threadIdx.x == 0)
+    #pragma unroll 8
+    for(int k = 0; k < 8; k++)
+      smem_C[k] = local_C[k];
+  else if(threadIdx.x >= 32)
+    // early return for unused warps
+    return;
+
+  __syncwarp();
+
+
+  //for(int k = 0; k < 8; k++)
+  //  if((float)local_C[k] != 0.0f)
+  //    printf("%i %f\n", threadIdx.x, (float)local_C[k]);
+
+  if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
+    out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
+
 }
+//#define ROWS 2
+//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+//{
+//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
+//// 1. Load dataB into register
+//// 2. Dequantize B
+//// 3. Fetch data from A and multiply
+//
+//  typedef cub::BlockLoad LoadA;
+//  //__shared__ typename LoadA::TempStorage loada;
+//  typedef cub::BlockLoad LoadB;
+//  //__shared__ typename LoadB::TempStorage loadb;
+//  typedef cub::BlockReduce BlockReduce;
+//  // Allocate shared memory for BlockReduce
+//  //__shared__ typename BlockReduce::TempStorage reduce;
+//
+//  __shared__ union {
+//    typename BlockReduce::TempStorage reduce;
+//    typename LoadB::TempStorage loadb;
+//    typename LoadA::TempStorage loada;
+//  } temp_storage;
+//
+//
+//  T dataA[ITEMS];
+//  T local_B[ITEMS];
+//  T local_accC[ROWS];
+//  int valid_items = 0;
+//  const int col_offset = blockIdx.x * 8;
+//
+//  __shared__ T tileA[ROWS*THREADS*ITEMS];
+//  __shared__ T accumulatorC[ROWS*8];
+//
+//  //#pragma unroll 8
+//  //for(int i = 0; i < 8; i++)
+//  //  tileA[threadIdx.x + (i*256)] = 0.0f;
+//  //__syncthreads();
+//  if(threadIdx.x < 64)
+//    accumulatorC[threadIdx.x] = 0.0f;
+//  __syncthreads();
+//
+//
+//  for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS)
+//  {
+//    valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
+//    int baserow = 0;
+//    for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
+//    {
+//      LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
+//
+//      #pragma unroll ITEMS
+//      for(int k = 0; k < ITEMS; k++)
+//        tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k];
+//
+//      __syncthreads();
+//    }
+//    baserow += ROWS;
+//
+//    // load 16 columns from B at a time. B is transposed, so its like loading rows
+//    // each warp loads one row
+//    // each thread loads 128 byte
+//
+//    // col: inner_idx + warp_lane
+//    // row: ldb*(offset + warp_id)
+//    for(int col = 0; col < 8 && (col_offset + col) < M; col++)
+//    {
+//      int colB = col_offset + col;
+//
+//      for(int k = 0; k < ROWS; k++)
+//        local_accC[k] = 0.0f;
+//
+//      int base_idxB = ldb*colB;
+//      valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
+//      LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
+//      __syncthreads();
+//
+//      for(int row = 0; row < ROWS && row < N; row++)
+//      {
+//        #pragma unroll ITEMS
+//        for(int k = 0; k < ITEMS; k++)
+//        {
+//          int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k);
+//          local_accC[row] += tileA[idxA]*local_B[k];
+//        }
+//
+//        local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
+//        if(threadIdx.x == 0)
+//          atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
+//      }
+//    }
+//  }
+//
+//  for(int row = 0; row < ROWS && row < N; row++)
+//  {
+//    int out_idx = ldc*row + col_offset;
+//
+//    //if(threadIdx.x < 8)
+//    //  if(accumulatorC[row*8 + threadIdx.x] != 0.0)
+//    //    printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
+//
+//    if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
+//    {
+//      //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
+//      out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
+//    }
+//  }
+//
+//
+//
+//}
+
 
 __device__ void compute(float* global_out, float const* shared_in)
 {
@@ -3122,10 +3217,8 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 //                     TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 //                     TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
 //                     half alpha, half beta);
-template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 
 //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 9603e93..23ecf45 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -138,6 +138,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
 
 template __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
-template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 
 #endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index aa3dacf..c0c2658 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -675,10 +675,10 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
 
 
-template void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc)
 {
 
-  dim3 dimBlock(256);
+  dim3 dimBlock(128);
   int num_blocks = (m+7)/8;
 
   cout << num_blocks << endl;
@@ -689,7 +689,7 @@ template void gemm_host(int m, int n, int k, T const* A, T* B, T
   cout << m << endl;
   cout << n << endl;
   cout << k << endl;
-  gemm_device
+  gemm_device
   <<< num_blocks, dimBlock, 0, 0 >>>
   (m,  n,  k,
    A,
@@ -701,8 +701,8 @@ template void gemm_host(int m, int n, int k, T const* A, T* B, T
 
 // TEMPLATE DEFINITIONS
 //==============================================================
 
-template void gemm_host(int m, int n, int k, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template void gemm_host(int m, int n, int k, half const* A, half* B, half * out, int lda, int ldb, int ldc);
+template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc);
+template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc);
 
 template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
 template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index b7ef9a3..8822640 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -190,7 +190,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id
 
 void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
 
-template void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc);
+template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc);
 
 void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
 
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 3dd0b05..f92b52f 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -20,9 +20,9 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
 void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); }
 
-void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc)
+void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
 { gemm_host(M, N, K, A, B, out, lda, ldb, ldc); }
 
-void gemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc)
+void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
 { gemm_host(M, N, K, A, B, out, lda, ldb, ldc); }
 
@@ -313,10 +313,10 @@ extern "C"
 	void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
 	void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }
 
-	void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc)
+	void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
 	{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
 
-	void cgemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc)
+	void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
 	{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
 
 #endif
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 1564306..f08c4a2 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2355,11 +2355,11 @@ def test_normal_map_tree():
 
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(2):
-        A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #A = torch.rand(2, 4, dtype=dtype, device='cuda')
-        #B = torch.rand(4, 4, dtype=dtype, device='cuda')
+    for i in range(1):
+        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
 
         #print('')
         #print(A)
@@ -2371,7 +2371,7 @@ def test_cutlass3_gemm(dtype):
         #print(C1)
         #print(C2)
 
-        #torch.testing.assert_close(C1, C2)
+        torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005)
 
 
 def test_pipeline_func():
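For reference, the rewritten `gemm_device` kernel in this diff reduces to a simple pattern: each thread block owns 8 output columns (`col_offset = blockIdx.x * 8`), every thread accumulates partial dot products over `K` in registers, the per-thread partials are summed with `cub::BlockReduce`, and the first threads of the block write the 8 results. The sketch below is a minimal, standalone illustration of that pattern only, assuming float inputs, a single row of `A` (the shape the updated test exercises), and `ldb == K`; the kernel name `gemv_rowmajor_sketch`, the bounds checks, and the launch setup are illustrative and are not the library kernel itself (which also uses vectorized 8-element loads and half precision).

```cuda
// Minimal sketch of the block-per-8-columns reduction pattern used by the new
// gemm_device kernel. Assumptions (not from the diff): float inputs, A is 1 x K,
// B is M x K row-major with ldb == K, out is 1 x M.
#include <cub/cub.cuh>

template <int THREADS>
__global__ void gemv_rowmajor_sketch(int M, int K,
                                     const float* __restrict__ A,
                                     const float* __restrict__ B,
                                     float* out)
{
    typedef cub::BlockReduce<float, THREADS> BlockReduce;
    __shared__ typename BlockReduce::TempStorage reduce;
    __shared__ float smem_C[8];

    const int col_offset = blockIdx.x * 8;

    // Per-thread partial sums for the 8 output columns owned by this block.
    float local_C[8];
    #pragma unroll
    for (int c = 0; c < 8; c++)
        local_C[c] = 0.0f;

    // Each thread strides over K and accumulates partial dot products.
    for (int idx = threadIdx.x; idx < K; idx += blockDim.x)
    {
        const float a = A[idx];
        #pragma unroll
        for (int c = 0; c < 8; c++)
        {
            const int col = col_offset + c;
            if (col < M)
                local_C[c] += a * B[col * K + idx];
        }
    }

    // Reduce the partials across the block, one column at a time.
    #pragma unroll
    for (int c = 0; c < 8; c++)
    {
        float colsum = BlockReduce(reduce).Sum(local_C[c]);
        if (threadIdx.x == 0)
            smem_C[c] = colsum;
        __syncthreads();  // TempStorage is reused in the next iteration
    }

    // Threads 0..7 write the block's 8 outputs.
    if (threadIdx.x < 8 && col_offset + threadIdx.x < M)
        out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
}

// Launch setup mirroring the host side in ops.cu: 128 threads per block,
// one block per 8 output columns.
// int num_blocks = (M + 7) / 8;
// gemv_rowmajor_sketch<128><<<num_blocks, 128>>>(M, K, dA, dB, dOut);
```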