Added template refactor.

Tim Dettmers 2023-04-28 17:34:08 -07:00
parent c1bfb210c5
commit 3aef78342a
6 changed files with 20 additions and 50 deletions

View File

@@ -1464,9 +1464,7 @@ def cutlass3_gemm(
     lda = ct.c_int32(lda)
     ldb = ct.c_int32(ldb)
     ldc = ct.c_int32(ldc)
-    alpha = ct.c_float(1.0)
-    beta = ct.c_float(0.0)
-    lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), lda, get_ptr(B), ldb, beta, get_ptr(out), ldc)
+    lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     return out

View File

@@ -2949,22 +2949,18 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 #define ROWS 2
-__global__ void gemm_device(int M, int N, int K,
-                            float const* A,
-                            float* B,
-                            float * out, int lda, int ldb, int ldc,
-                            float alpha, float beta)
+template <typename T> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
 {
   // 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
   // 1. Load dataB into register
   // 2. Dequantize B
   // 3. Fetch data from A and multiply
-  typedef cub::BlockLoad<float, 256 , 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
+  typedef cub::BlockLoad<T, 256 , 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
   //__shared__ typename LoadA::TempStorage loada;
-  typedef cub::BlockLoad<float, 256 , 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
+  typedef cub::BlockLoad<T, 256 , 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
   //__shared__ typename LoadB::TempStorage loadb;
-  typedef cub::BlockReduce<float, 256> BlockReduce;
+  typedef cub::BlockReduce<T, 256> BlockReduce;
   // Allocate shared memory for BlockReduce
   //__shared__ typename BlockReduce::TempStorage reduce;
@@ -2975,16 +2971,16 @@ __global__ void gemm_device(int M, int N, int K,
   } temp_storage;
-  float dataA[4];
-  float local_B[4];
-  float local_accC[ROWS];
+  T dataA[4];
+  T local_B[4];
+  T local_accC[ROWS];
   int valid_items = 0;
   const int warp_id = threadIdx.x/32;
   const int warp_lane = threadIdx.x % 32;
   const int col_offset = blockIdx.x * 8;
-  __shared__ float tileA[ROWS*1024];
-  __shared__ float accumulatorC[ROWS*8];
+  __shared__ T tileA[ROWS*1024];
+  __shared__ T accumulatorC[ROWS*8];
   //#pragma unroll 8
   //for(int i = 0; i < 8; i++)
@@ -3128,6 +3124,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 // TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 // TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
 // half alpha, half beta);
+template __global__ void gemm_device<float>(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc);
 //template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
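The added `template __global__ void gemm_device<float>(...)` line is an explicit instantiation: the templated kernel is now defined in this .cu file but only declared in the header, so the compiler must be told which concrete types to emit code for, otherwise other translation units fail to link. A minimal sketch of the same pattern, using a hypothetical `scale_kernel` that is not part of this commit:

// scale.cu -- templated __global__ kernel defined in one translation unit
template <typename T>
__global__ void scale_kernel(T *data, T factor, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    data[i] *= factor;  // one thread per element
}

// Explicit instantiation: without it, code that only sees the declaration
// (e.g. via a scale.cuh header) gets an unresolved symbol for scale_kernel<float>.
template __global__ void scale_kernel<float>(float *data, float factor, int n);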

View File

@@ -138,10 +138,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <size_t stages_count /* Pipeline with stages_count stages */>
 __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
-__global__ void gemm_device(int M, int N, int K,
-                            float const* A,
-                            float * B,
-                            float * out, int lda, int ldb, int ldc,
-                            float alpha, float beta);
+template <typename T> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc);
 #endif

View File

@@ -675,12 +675,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
-void gemm_host(int m, int n, int k,
-               float alpha,
-               float const* A, int lda,
-               float * B, int ldb,
-               float beta,
-               float * C, int ldc)
+template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc)
 {
   dim3 dimBlock(256);
@@ -699,14 +694,14 @@ void gemm_host(int m, int n, int k,
     (m, n, k,
     A,
     B,
-    C, lda, ldb, ldc,
-    alpha, beta);
+    out, lda, ldb, ldc);
 }

 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================

+template void gemm_host<float>(int m, int n, int k, float const* A, float* B, float * out, int lda, int ldb, int ldc);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
 template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
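With the launch now forwarding `out, lda, ldb, ldc` (and no `alpha`/`beta`), the explicit `gemm_host<float>` instantiation above is the symbol the extern "C" wrappers further down link against. A minimal, hypothetical usage sketch of the refactored host entry point, assuming the declaration from the header change in this commit and a build that links against this library; sizes and leading dimensions are chosen only for illustration and may not match the kernel's own tiling assumptions:

#include <cuda_runtime.h>

// Refactored declaration from this commit (normally included via the ops header).
template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc);

int main()
{
  // Illustrative sizes only; device buffers are left uninitialized since this
  // snippet only demonstrates the new call signature.
  int m = 8, n = 8, k = 256;
  int lda = k, ldb = k, ldc = m;  // leading-dimension choice is an assumption here

  float *A, *B, *out;
  cudaMalloc((void**)&A,   sizeof(float) * m * k);
  cudaMalloc((void**)&B,   sizeof(float) * k * n);
  cudaMalloc((void**)&out, sizeof(float) * m * n);

  // New signature: no alpha/beta; out and the leading dimensions come last.
  gemm_host<float>(m, n, k, A, B, out, lda, ldb, ldc);
  cudaDeviceSynchronize();

  cudaFree(A); cudaFree(B); cudaFree(out);
  return 0;
}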

View File

@@ -190,12 +190,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
-void gemm_host(int m, int n, int k,
-               float alpha,
-               float const* A, int ldA,
-               float * B, int ldB,
-               float beta,
-               float * C, int ldC);
+template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc);
 void pipeline_test(float *A, float *B, size_t n, size_t batch_size);

View File

@@ -20,14 +20,8 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
 void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
-void
-cppgemm(int m, int n, int k,
-        float alpha,
-        float const* A, int ldA,
-        float * B, int ldB,
-        float beta,
-        float * C, int ldC)
-{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
+void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc)
+{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc); }
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
@@ -317,13 +311,8 @@ extern "C"
 void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
 void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }
-void ccutlass_gemm(int m, int n, int k,
-                   float alpha,
-                   float const* A, int ldA,
-                   float * B, int ldB,
-                   float beta,
-                   float * C, int ldC)
-{ cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
+void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc)
+{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
 #endif
 void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }