New implementation for batch size 1.

parent f6df4aef6a
commit f3e97ccbd2

csrc/kernels.cu (271 lines changed)

@@ -2947,117 +2947,212 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 //// 9. write outputs to matmul output matrix
 //}


 #define ROWS 2
-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
-// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
-// 1. Load dataB into register
-// 2. Dequantize B
-// 3. Fetch data from A and multiply
-
-  typedef cub::BlockLoad<T, THREADS , ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
-  //__shared__ typename LoadA::TempStorage loada;
-  typedef cub::BlockLoad<T, THREADS , ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
-  //__shared__ typename LoadB::TempStorage loadb;
-  typedef cub::BlockReduce<T, THREADS> BlockReduce;
-  // Allocate shared memory for BlockReduce
-  //__shared__ typename BlockReduce::TempStorage reduce;
-
-  __shared__ union {
-    typename BlockReduce::TempStorage reduce;
-    typename LoadB::TempStorage loadb;
-    typename LoadA::TempStorage loada;
-  } temp_storage;
-
-
-  T dataA[ITEMS];
-  T local_B[ITEMS];
-  T local_accC[ROWS];
-  int valid_items = 0;
-  const int col_offset = blockIdx.x * 8;
-
-  __shared__ T tileA[ROWS*THREADS*ITEMS];
-  __shared__ T accumulatorC[ROWS*8];
-
-  //#pragma unroll 8
-  //for(int i = 0; i < 8; i++)
-  //  tileA[threadIdx.x + (i*256)] = 0.0f;
-  //__syncthreads();
-  if(threadIdx.x < 64)
-    accumulatorC[threadIdx.x] = 0.0f;
-  __syncthreads();
-
-  for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS)
-  {
-    valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
-    int baserow = 0;
-    for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
-    {
-      LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
-
-      #pragma unroll ITEMS
-      for(int k = 0; k < ITEMS; k++)
-        tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k];
-
-      __syncthreads();
-    }
-    baserow += ROWS;
-
-    // load 16 columns from B at a time. B is transposed, so its like loading rows
-    // each warp loads one row
-    // each thread loads 128 byte
-
-    // col: inner_idx + warp_lane
-    // row: ldb*(offset + warp_id)
-    for(int col = 0; col < 8 && (col_offset + col) < M; col++)
-    {
-      int colB = col_offset + col;
-
-      for(int k = 0; k < ROWS; k++)
-        local_accC[k] = 0.0f;
-
-      int base_idxB = ldb*colB;
-      valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
-      LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
-      __syncthreads();
-
-      for(int row = 0; row < ROWS && row < N; row++)
-      {
-        #pragma unroll ITEMS
-        for(int k = 0; k < ITEMS; k++)
-        {
-          int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k);
-          local_accC[row] += tileA[idxA]*local_B[k];
-        }
-
-        local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
-        if(threadIdx.x == 0)
-          atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
-      }
-    }
-  }
-
-  for(int row = 0; row < ROWS && row < N; row++)
-  {
-    int out_idx = ldc*row + col_offset;
-
-    //if(threadIdx.x < 8)
-    //  if(accumulatorC[row*8 + threadIdx.x] != 0.0)
-    //    printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
-
-    if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
-    {
-      //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
-      out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
-    }
-  }
-
+  typedef cub::BlockReduce<T, THREADS> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage reduce;
+  int col_offset = blockIdx.x *8;
+
+  T local_A[8];
+  T local_B[8];
+  T local_C[8];
+
+  __shared__ T smem_C[8];
+
+  if(threadIdx.x < 8)
+    smem_C[threadIdx.x] = T(0);
+
+  __syncthreads();
+
+  #pragma unroll 8
+  for(int k = 0; k < 8; k++)
+    local_C[k] = T(0);
+
+  for(int idx = threadIdx.x*8; idx < K; idx+=blockDim.x*8)
+  {
+    if(idx + 8 <= K)
+      reinterpret_cast<float4(&)[8]>(local_A)[0] = reinterpret_cast<float4*>(A)[idx/8];
+    else
+    {
+      for(int k = 0; k < 8; k++)
+      {
+        if(idx + k < K)
+          local_A[k] = A[idx+k];
+        else
+          local_A[k] = 0.0f;
+      }
+    }
+
+    for(int col = 0; col < 8; col++)
+    {
+      int offset_B = (col_offset+col)*ldb;
+      if(idx + 8 <= K)
+        reinterpret_cast<float4(&)[8]>(local_B)[0] = reinterpret_cast<float4*>(B)[(offset_B+idx)/8];
+      else
+      {
+        for(int k = 0; k < 8; k++)
+        {
+          if(idx + k < K)
+            local_B[k] = B[(offset_B+idx)+k];
+          else
+            local_B[k] = 0.0f;
+        }
+      }
+
+      #pragma unroll 8
+      for(int k = 0; k < 8; k++)
+      {
+        local_C[col] += local_A[k]*local_B[k];
+        //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
+        //  printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]);
+      }
+    }
+  }
+
+  #pragma unroll 8
+  for(int k = 0; k < 8; k++)
+  {
+    local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
+    __syncthreads();
+  }
+
+  if(threadIdx.x == 0)
+    #pragma unroll 8
+    for(int k = 0; k < 8; k++)
+      smem_C[k] = local_C[k];
+  else if(threadIdx.x >= 32)
+    // early return for unused warps
+    return;
+
+  __syncwarp();
+
+  //for(int k = 0; k < 8; k++)
+  //  if((float)local_C[k] != 0.0f)
+  //    printf("%i %f\n", threadIdx.x, (float)local_C[k]);
+
+  if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
+    out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
 }
+
+//#define ROWS 2
+//template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+//{
+//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
+//// 1. Load dataB into register
+//// 2. Dequantize B
+//// 3. Fetch data from A and multiply
+//
+//  typedef cub::BlockLoad<T, THREADS , ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
+//  //__shared__ typename LoadA::TempStorage loada;
+//  typedef cub::BlockLoad<T, THREADS , ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
+//  //__shared__ typename LoadB::TempStorage loadb;
+//  typedef cub::BlockReduce<T, THREADS> BlockReduce;
+//  // Allocate shared memory for BlockReduce
+//  //__shared__ typename BlockReduce::TempStorage reduce;
+//
+//  __shared__ union {
+//    typename BlockReduce::TempStorage reduce;
+//    typename LoadB::TempStorage loadb;
+//    typename LoadA::TempStorage loada;
+//  } temp_storage;
+//
+//
+//  T dataA[ITEMS];
+//  T local_B[ITEMS];
+//  T local_accC[ROWS];
+//  int valid_items = 0;
+//  const int col_offset = blockIdx.x * 8;
+//
+//  __shared__ T tileA[ROWS*THREADS*ITEMS];
+//  __shared__ T accumulatorC[ROWS*8];
+//
+//  //#pragma unroll 8
+//  //for(int i = 0; i < 8; i++)
+//  //  tileA[threadIdx.x + (i*256)] = 0.0f;
+//  //__syncthreads();
+//  if(threadIdx.x < 64)
+//    accumulatorC[threadIdx.x] = 0.0f;
+//  __syncthreads();
+//
+//
+//  for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS)
+//  {
+//    valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
+//    int baserow = 0;
+//    for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
+//    {
+//      LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
+//
+//      #pragma unroll ITEMS
+//      for(int k = 0; k < ITEMS; k++)
+//        tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k];
+//
+//      __syncthreads();
+//    }
+//    baserow += ROWS;
+//
+//    // load 16 columns from B at a time. B is transposed, so its like loading rows
+//    // each warp loads one row
+//    // each thread loads 128 byte
+//
+//    // col: inner_idx + warp_lane
+//    // row: ldb*(offset + warp_id)
+//    for(int col = 0; col < 8 && (col_offset + col) < M; col++)
+//    {
+//      int colB = col_offset + col;
+//
+//      for(int k = 0; k < ROWS; k++)
+//        local_accC[k] = 0.0f;
+//
+//      int base_idxB = ldb*colB;
+//      valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
+//      LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
+//      __syncthreads();
+//
+//      for(int row = 0; row < ROWS && row < N; row++)
+//      {
+//        #pragma unroll ITEMS
+//        for(int k = 0; k < ITEMS; k++)
+//        {
+//          int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k);
+//          local_accC[row] += tileA[idxA]*local_B[k];
+//        }
+//
+//        local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
+//        if(threadIdx.x == 0)
+//          atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
+//      }
+//    }
+//  }
+//
+//  for(int row = 0; row < ROWS && row < N; row++)
+//  {
+//    int out_idx = ldc*row + col_offset;
+//
+//    //if(threadIdx.x < 8)
+//    //  if(accumulatorC[row*8 + threadIdx.x] != 0.0)
+//    //    printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
+//
+//    if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
+//    {
+//      //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
+//      out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
+//    }
+//  }
+//
+//
+//
+//}


 __device__ void compute(float* global_out, float const* shared_in)
 {
@@ -3122,10 +3217,8 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 //      TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 //      TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
 //      half alpha, half beta);
-template __global__ void gemm_device<float, 4, 256>(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device<half, 4, 256>(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device<float, 8, 256>(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device<half, 8, 256>(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);


 //template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
@@ -138,6 +138,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <size_t stages_count /* Pipeline with stages_count stages */>
 __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);

-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);

 #endif
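
Note: the new gemm_device path above is essentially a batched dot product for a single input row: each thread block owns 8 output columns, each thread accumulates a strided slice of the K dimension using 16-byte vectorized loads, and a cub::BlockReduce collapses the per-thread partial sums before 8 threads write the block's outputs. Below is a minimal, float-only sketch of that pattern (not from this commit); the kernel and macro names (gemv_bs1_sketch, THREADS_SKETCH) are hypothetical, and it assumes A and B are 16-byte aligned with ldb a multiple of 4.

// Sketch only: same structure as the new gemm_device, reduced to float and a
// fixed 4-wide vector width so the control flow is easier to follow.
#include <cub/cub.cuh>

#define THREADS_SKETCH 128

// out[col] = dot(A[0:K], B[col, 0:K]) for the 8 columns owned by this block.
__global__ void gemv_bs1_sketch(int M, int K, const float* __restrict__ A,
                                const float* __restrict__ B, float* out, int ldb)
{
  typedef cub::BlockReduce<float, THREADS_SKETCH> BlockReduce;
  __shared__ typename BlockReduce::TempStorage reduce;
  __shared__ float smem_C[8];

  const int col_offset = blockIdx.x * 8;
  float local_C[8] = {};                       // per-thread partial dot products

  // Each thread walks K in strides of blockDim.x*4, pulling 4 values of A and
  // of each owned B column with one float4 load whenever the tail allows it.
  for (int idx = threadIdx.x * 4; idx < K; idx += blockDim.x * 4)
  {
    alignas(16) float a[4] = {};
    if (idx + 4 <= K)
      *reinterpret_cast<float4*>(a) = reinterpret_cast<const float4*>(A)[idx / 4];
    else
      for (int k = 0; k < 4 && idx + k < K; k++) a[k] = A[idx + k];

    for (int col = 0; col < 8 && col_offset + col < M; col++)
    {
      const float* Bcol = B + (size_t)(col_offset + col) * ldb;  // one row of B per output
      alignas(16) float b[4] = {};
      if (idx + 4 <= K)
        *reinterpret_cast<float4*>(b) = reinterpret_cast<const float4*>(Bcol)[idx / 4];
      else
        for (int k = 0; k < 4 && idx + k < K; k++) b[k] = Bcol[idx + k];

      for (int k = 0; k < 4; k++)
        local_C[col] += a[k] * b[k];
    }
  }

  // One block-wide reduction per output column; only thread 0 holds the sum.
  for (int col = 0; col < 8; col++)
  {
    float sum = BlockReduce(reduce).Reduce(local_C[col], cub::Sum());
    if (threadIdx.x == 0) smem_C[col] = sum;
    __syncthreads();                           // temp storage is reused next iteration
  }

  if (threadIdx.x < 8 && col_offset + threadIdx.x < M)
    out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
}

It would be launched as gemv_bs1_sketch<<<(M + 7) / 8, THREADS_SKETCH>>>(M, K, A, B, out, ldb), mirroring the (m+7)/8 grid used by gemm_host below; the committed kernel follows the same structure with 8-element register tiles of T and a T-typed BlockReduce over 128 threads.
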
csrc/ops.cu (10 lines changed)
@@ -675,10 +675,10 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)



-template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc)
 {

-  dim3 dimBlock(256);
+  dim3 dimBlock(128);
   int num_blocks = (m+7)/8;

   cout << num_blocks << endl;
@@ -689,7 +689,7 @@ template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T
   cout << m << endl;
   cout << n << endl;
   cout << k << endl;
-  gemm_device<T, 8, 256>
+  gemm_device<T, 16, 128>
     <<< num_blocks, dimBlock, 0, 0 >>>
     (m, n, k,
     A,
@@ -701,8 +701,8 @@ template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T
 // TEMPLATE DEFINITIONS
 //==============================================================

-template void gemm_host<float>(int m, int n, int k, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template void gemm_host<half>(int m, int n, int k, half const* A, half* B, half * out, int lda, int ldb, int ldc);
+template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc);
+template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
 template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
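
For context, a hypothetical end-to-end driver for the sketch kernel shown earlier, wired the way the updated gemm_host is: 128 threads per block and one block per 8 output values, with sizes mirroring the updated test (K = 4096 inputs, M = 4*4096 outputs, batch size 1). None of these names come from this commit and error checking is omitted.

#include <cuda_runtime.h>
#include <vector>

// Defined in the sketch above; declared here so this snippet stands alone.
__global__ void gemv_bs1_sketch(int M, int K, const float* A, const float* B,
                                float* out, int ldb);

int main()
{
  const int K = 4096, M = 4 * 4096, ldb = K;
  std::vector<float> hA(K, 1.0f), hB((size_t)M * ldb, 0.5f), hOut(M);

  float *dA, *dB, *dOut;
  cudaMalloc(&dA, K * sizeof(float));
  cudaMalloc(&dB, (size_t)M * ldb * sizeof(float));
  cudaMalloc(&dOut, M * sizeof(float));
  cudaMemcpy(dA, hA.data(), K * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), (size_t)M * ldb * sizeof(float), cudaMemcpyHostToDevice);

  dim3 dimBlock(128);               // matches the new <T, 16, 128> configuration
  int num_blocks = (M + 7) / 8;     // one block per 8 output columns, as in gemm_host
  gemv_bs1_sketch<<<num_blocks, dimBlock>>>(M, K, dA, dB, dOut, ldb);

  cudaMemcpy(hOut.data(), dOut, M * sizeof(float), cudaMemcpyDeviceToHost);
  // Every output should be the dot product of 1.0 and 0.5 over K elements, i.e. 2048.
  cudaFree(dA); cudaFree(dB); cudaFree(dOut);
  return 0;
}
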
@@ -190,7 +190,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id

 void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);

-template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc);


 void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
@@ -20,9 +20,9 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
 void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }


-void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc)
+void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
 { gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc); }
-void gemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc)
+void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
 { gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc); }

@@ -313,10 +313,10 @@ extern "C"
   void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
   void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }

-  void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc)
+  void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
   { gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }

-  void cgemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc)
+  void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
   { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }

 #endif
@@ -2355,11 +2355,11 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(2):
-        A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #A = torch.rand(2, 4, dtype=dtype, device='cuda')
-        #B = torch.rand(4, 4, dtype=dtype, device='cuda')
+    for i in range(1):
+        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')

         #print('')
         #print(A)
@@ -2371,7 +2371,7 @@ def test_cutlass3_gemm(dtype):
         #print(C1)
         #print(C2)

-        #torch.testing.assert_close(C1, C2)
+        torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005)


 def test_pipeline_func():