diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 3310285..a5697ee 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2947,16 +2947,31 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 //// 9. write outputs to matmul output matrix
 //}
 
-#define ROWS 2
-template <typename T, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit)
+{
+  if(limit_base + ITEMS <= limit)
+    reinterpret_cast<float4*>(local)[0] = reinterpret_cast<float4*>(buffer)[idx/ITEMS];
+  else
+  {
+    for(int k = 0; k < ITEMS; k++)
+    {
+      if(limit_base + k < limit)
+        local[k] = buffer[idx+k];
+      else
+        local[k] = 0.0f;
+    }
+  }
+}
+
+template <typename T, int BITS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
   typedef cub::BlockReduce<T, 128> BlockReduce;
   __shared__ typename BlockReduce::TempStorage reduce;
   int col_offset = blockIdx.x *8;
 
-  T local_A[8];
-  T local_B[8];
+  T local_A[128/BITS];
+  T local_B[128/BITS];
   T local_C[8];
 
   __shared__ T smem_C[8];
 
@@ -2970,47 +2985,18 @@ template <typename T, int THREADS> __global__ void gemm_device(int M,
     local_C[k] = T(0);
 
-  for(int idx = threadIdx.x*8; idx < K; idx+=blockDim.x*8)
+  for(int idx = threadIdx.x*128/BITS; idx < K; idx+=blockDim.x*128/BITS)
   {
-
-    if(idx + 8 <= K)
-      reinterpret_cast<float4*>(local_A)[0] = reinterpret_cast<float4*>(A)[idx/8];
-    else
-    {
-      for(int k = 0; k < 8; k++)
-      {
-        if(idx + k < K)
-          local_A[k] = A[idx+k];
-        else
-          local_A[k] = 0.0f;
-      }
-    }
-
+    vector_load<T, 128/BITS>(local_A, A, idx, idx, K);
 
     for(int col = 0; col < 8; col++)
    {
       int offset_B = (col_offset+col)*ldb;
-      if(idx + 8 <= K)
-        reinterpret_cast<float4*>(local_B)[0] = reinterpret_cast<float4*>(B)[(offset_B+idx)/8];
-      else
-      {
-        for(int k = 0; k < 8; k++)
-        {
-          if(idx + k < K)
-            local_B[k] = B[(offset_B+idx)+k];
-          else
-            local_B[k] = 0.0f;
-        }
-      }
+      vector_load<T, 128/BITS>(local_B, B, offset_B+idx, idx, K);
 
-      #pragma unroll 8
-      for(int k = 0; k < 8; k++)
-      {
+      #pragma unroll 128/BITS
+      for(int k = 0; k < 128/BITS; k++)
        local_C[col] += local_A[k]*local_B[k];
-        //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
-        //  printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]);
-      }
-
     }
   }
 
@@ -3022,9 +3008,11 @@ template <typename T, int THREADS> __global__ void gemm_device(int M,
   }
 
   if(threadIdx.x == 0)
+  {
     #pragma unroll 8
     for(int k = 0; k < 8; k++)
       smem_C[k] = local_C[k];
+  }
   else if(threadIdx.x >= 32)
     // early return for unused warps
     return;
@@ -3032,15 +3020,8 @@ template <typename T, int THREADS> __global__ void gemm_device(int M,
 
   __syncwarp();
 
-  //for(int k = 0; k < 8; k++)
-  //  if((float)local_C[k] != 0.0f)
-  //    printf("%i %f\n", threadIdx.x, (float)local_C[k]);
-
   if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
     out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
-
-
-
 }
 
 //#define ROWS 2
@@ -3217,7 +3198,13 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 
 //         TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 //         TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
 //         half alpha, half beta);
+
+// these are not used and make no sense, but the compiler needs them
 template __global__ void gemm_device<float, 16>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+// these are not used and make no sense, but the compiler needs them
+
+template __global__ void gemm_device<float, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 23ecf45..aab7b95 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -138,6 +138,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 
 template <int STAGES> __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
-template <typename T, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T, int BITS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 
 #endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index c0c2658..2219690 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -675,7 +675,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
 
 
 
-template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
 {
 
   dim3 dimBlock(128);
@@ -689,20 +689,18 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
   cout << m << endl;
   cout << n << endl;
   cout << k << endl;
-  gemm_device
-  <<< num_blocks, dimBlock, 0, 0 >>>
-  (m, n, k,
-  A,
-  B,
-  out, lda, ldb, ldc);
+  if(bits == 32)
+    gemm_device<T, 32><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+  else if(bits == 16)
+    gemm_device<T, 16><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
 }
 
 //==============================================================
 //                   TEMPLATE DEFINITIONS
 //==============================================================
 
-template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc);
-template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc);
+template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
+template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
 template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 8822640..ffc9e87 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -190,7 +190,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 
 void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
 
-template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 
 void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
 
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index f92b52f..1ece3e6 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -21,9 +21,9 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
 
 
 void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
-{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc); }
+{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); }
 
 void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
-{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc); }
+{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); }
 
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
diff --git a/tests/test_functional.py b/tests/test_functional.py
index f08c4a2..b256af9 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2352,8 +2352,8 @@ def test_normal_map_tree():
     print(pivots)
 
 
-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
     for i in range(1):
         #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
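
The pattern this diff factors out into vector_load is a 128-bit vectorized load (one float4 transaction per thread) with a guarded scalar fallback for the tail of the K dimension, parameterized by the element width so that fp16 fetches 8 values and fp32 fetches 4 per load. Below is a minimal, self-contained sketch of that pattern, not the library's code: the names vector_load_sketch and dot_kernel are invented for this illustration, the index and limit_base arguments of the diff's helper are merged into a single index, and a plain atomicAdd stands in for the cub::BlockReduce used by gemm_device.

// Minimal sketch (illustrative names, not the bitsandbytes API).
// Compile with: nvcc -o dot_sketch dot_sketch.cu
#include <cstdio>
#include <cuda_runtime.h>

// Load ITEMS elements (ITEMS * sizeof(T) == 16 bytes) starting at buffer[idx].
// If the chunk would run past `limit`, fall back to element-wise loads and
// zero-fill the tail, mirroring the fallback branch of vector_load() above.
template <typename T, int ITEMS>
__device__ inline void vector_load_sketch(T *local, const T *buffer, int idx, int limit)
{
  if (idx + ITEMS <= limit)
    *reinterpret_cast<float4 *>(local) = *reinterpret_cast<const float4 *>(buffer + idx);
  else
    for (int k = 0; k < ITEMS; k++)
      local[k] = (idx + k < limit) ? buffer[idx + k] : T(0);
}

// Each thread walks the K dimension in 128-bit chunks and accumulates a
// partial dot product; a plain atomicAdd replaces the block-wide reduction.
template <typename T, int BITS>
__global__ void dot_kernel(const T *a, const T *b, float *out, int K)
{
  constexpr int ITEMS = 128 / BITS;            // 8 halves or 4 floats per load
  T la[ITEMS], lb[ITEMS];
  float acc = 0.0f;
  for (int idx = threadIdx.x * ITEMS; idx < K; idx += blockDim.x * ITEMS)
  {
    vector_load_sketch<T, ITEMS>(la, a, idx, K);
    vector_load_sketch<T, ITEMS>(lb, b, idx, K);
    for (int k = 0; k < ITEMS; k++)
      acc += (float)la[k] * (float)lb[k];
  }
  atomicAdd(out, acc);
}

int main()
{
  const int K = 1001;                          // deliberately not a multiple of the vector width
  float *a, *b, *out;
  cudaMallocManaged(&a, K * sizeof(float));
  cudaMallocManaged(&b, K * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < K; i++) { a[i] = 1.0f; b[i] = 2.0f; }
  *out = 0.0f;
  dot_kernel<float, 32><<<1, 128>>>(a, b, out, K);
  cudaDeviceSynchronize();
  printf("dot = %f (expected %f)\n", *out, 2.0f * K);
  cudaFree(a); cudaFree(b); cudaFree(out);
  return 0;
}

A 16-byte transaction is the widest per-thread load the hardware issues, which is why the diff sizes local_A and local_B as 128/BITS elements instead of a fixed 8: the same code path then stays fully vectorized for both the fp16 and the new fp32 compute type selected through the added `bits` argument.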