Added bit template.

Tim Dettmers 2023-04-28 22:10:42 -07:00
parent f3e97ccbd2
commit cad839941b
6 changed files with 45 additions and 60 deletions

View File

@@ -2947,16 +2947,31 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// 9. write outputs to matmul output matrix
//}
#define ROWS 2
-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit)
+{
+  if(limit_base + ITEMS <= limit)
+    reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
+  else
+  {
+    for(int k = 0; k < ITEMS; k++)
+    {
+      if(limit_base + k < limit)
+        local[k] = buffer[idx+k];
+      else
+        local[k] = 0.0f;
+    }
+  }
+}
+template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
typedef cub::BlockReduce<T, THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage reduce;
int col_offset = blockIdx.x *8;
-  T local_A[8];
-  T local_B[8];
+  T local_A[128/BITS];
+  T local_B[128/BITS];
T local_C[8];
__shared__ T smem_C[8];
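Editor's note on the new helper: vector_load factors the previously inlined bounds-checked load into a reusable device function. When a full 128-bit chunk fits under limit it issues a single TCAST-wide load (an int4, 16 bytes, at the call sites below); otherwise it falls back to scalar loads and zero-fills the tail. A minimal standalone sketch of the same pattern, assuming half elements (ITEMS = 128/16 = 8); the kernel name and setup are illustrative, not part of the commit:

#include <cuda_fp16.h>

// Illustrative kernel, not from the commit: copies n half values using the
// same guarded 128-bit load that vector_load implements. Assumes src and dst
// are 16-byte aligned, as the int4 reinterpret_cast requires.
__global__ void guarded_copy(half *dst, const half * __restrict__ src, int n)
{
  const int ITEMS = 8; // 128 bits / 16 bits per half
  half local[ITEMS];
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * ITEMS;
  if(idx >= n) return;
  if(idx + ITEMS <= n)
    // fast path: one 16-byte vector load instead of eight 2-byte loads
    reinterpret_cast<int4*>(local)[0] = reinterpret_cast<const int4*>(src)[idx / ITEMS];
  else
    for(int k = 0; k < ITEMS; k++) // ragged tail: scalar loads, zero fill
      local[k] = (idx + k < n) ? src[idx + k] : __float2half(0.0f);
  for(int k = 0; k < ITEMS; k++)
    if(idx + k < n) dst[idx + k] = local[k];
}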
@@ -2970,47 +2985,18 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
local_C[k] = T(0);
-  for(int idx = threadIdx.x*8; idx < K; idx+=blockDim.x*8)
+  for(int idx = threadIdx.x*128/BITS; idx < K; idx+=blockDim.x*128/BITS)
{
-    if(idx + 8 <= K)
-      reinterpret_cast<float4(&)[8]>(local_A)[0] = reinterpret_cast<float4*>(A)[idx/8];
-    else
-    {
-      for(int k = 0; k < 8; k++)
-      {
-        if(idx + k < K)
-          local_A[k] = A[idx+k];
-        else
-          local_A[k] = 0.0f;
-      }
-    }
+    vector_load<T, int4, 128/BITS>(local_A, A, idx, idx, K);
for(int col = 0; col < 8; col++)
{
int offset_B = (col_offset+col)*ldb;
-      if(idx + 8 <= K)
-        reinterpret_cast<float4(&)[8]>(local_B)[0] = reinterpret_cast<float4*>(B)[(offset_B+idx)/8];
-      else
-      {
-        for(int k = 0; k < 8; k++)
-        {
-          if(idx + k < K)
-            local_B[k] = B[(offset_B+idx)+k];
-          else
-            local_B[k] = 0.0f;
-        }
-      }
+      vector_load<T, int4, 128/BITS>(local_B, B, offset_B+idx, idx, K);
-      #pragma unroll 8
-      for(int k = 0; k < 8; k++)
-      {
+      #pragma unroll 128/BITS
+      for(int k = 0; k < 128/BITS; k++)
        local_C[col] += local_A[k]*local_B[k];
-        //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
-        //  printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]);
-      }
}
}
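The payoff of the template change in this hunk: the per-thread tile is no longer hardcoded to 8 elements but derived from the element width, while each loop iteration still moves exactly one 128-bit (int4) chunk per operand. A compile-time sanity sketch of the arithmetic (an illustration, not code from the commit):

// 128 bits per vector load, divided by the element width in bits:
template <int BITS> struct ItemsPerLoad { static constexpr int value = 128 / BITS; };
static_assert(ItemsPerLoad<16>::value == 8, "half: eight elements per int4 load");
static_assert(ItemsPerLoad<32>::value == 4, "float: four elements per int4 load");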
@@ -3022,9 +3008,11 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
}
if(threadIdx.x == 0)
{
#pragma unroll 8
for(int k = 0; k < 8; k++)
smem_C[k] = local_C[k];
}
else if(threadIdx.x >= 32)
// early return for unused warps
return;
@@ -3032,15 +3020,8 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
__syncwarp();
-  //for(int k = 0; k < 8; k++)
-  //  if((float)local_C[k] != 0.0f)
-  //    printf("%i %f\n", threadIdx.x, (float)local_C[k]);
if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
}
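For readers skimming the diff: the block-reduce lines sit in elided context, but the visible epilogue has thread 0 stage the 8 per-column sums in smem_C, retires the warps past the first, and lets the first 8 threads publish the results. A hedged, self-contained sketch of that pattern (reorganized with __syncthreads in place of the kernel's early return plus __syncwarp; names are illustrative, not the commit's exact code):

#include <cub/cub.cuh>

template <typename T, int THREADS>
__device__ void block_epilogue(T (&local_C)[8], T *smem_C, T *out, int col_offset, int M,
                               typename cub::BlockReduce<T, THREADS>::TempStorage &reduce)
{
  for(int k = 0; k < 8; k++)
  {
    // sum this column's partial products across the block; result lands on thread 0
    T total = cub::BlockReduce<T, THREADS>(reduce).Sum(local_C[k]);
    if(threadIdx.x == 0)
      smem_C[k] = total;
    __syncthreads(); // orders the smem write and allows temp-storage reuse
  }
  // the first 8 threads publish the block's 8 outputs, guarded against M
  if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
    out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
}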
//#define ROWS 2
@@ -3217,7 +3198,13 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
// TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
// half alpha, half beta);
+// these are not used and make no sense, but the compiler needs them
+template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+// these are not used and make no sense, but the compiler needs them
+template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
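A gloss on the "compiler needs them" comment: gemm_host<T> (changed below) branches on bits at runtime, so for every element type the compiler must emit both BITS variants, including the pairings (float/16, half/32) that no caller ever dispatches. A self-contained sketch of the underlying two-translation-unit pattern; the file roles and the kernel are illustrative, not from the commit:

// --- definitions TU (plays the role of kernels.cu) ---
template <typename T, int BITS>
__global__ void kernel(T *x) { x[threadIdx.x] = x[threadIdx.x] * T(BITS); }

// without explicit instantiation lines like these, no device code for the
// combinations below is emitted, and linking the launcher TU fails
template __global__ void kernel<float, 16>(float *);
template __global__ void kernel<float, 32>(float *);

// --- launcher TU (plays the role of ops.cu) ---
template <typename T, int BITS> __global__ void kernel(T *x);

void launch(float *x, int bits)
{
  // the runtime branch references BOTH instantiations, even though a given
  // call site only ever exercises one of them
  if(bits == 32) kernel<float, 32><<<1, 32>>>(x);
  else           kernel<float, 16><<<1, 32>>>(x);
}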

View File

@@ -138,6 +138,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
template <size_t stages_count /* Pipeline with stages_count stages */>
__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
#endif

View File

@@ -675,7 +675,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
-template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
dim3 dimBlock(128);
@@ -689,20 +689,18 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
cout << m << endl;
cout << n << endl;
cout << k << endl;
-  gemm_device<T, 16, 128>
-    <<< num_blocks, dimBlock, 0, 0 >>>
-    (m, n, k,
-    A,
-    B,
-    out, lda, ldb, ldc);
+  if(bits == 32)
+    gemm_device<T, 32, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+  else if(bits == 16)
+    gemm_device<T, 16, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
}
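A hedged usage sketch for the new signature (buffer names and leading-dimension choices are assumptions; the diff shows no caller besides the wrappers below): bits is expected to match the element width, 16 for half and 32 for float, so that the kernel's 128/BITS tiling corresponds to one int4 load.

#include <cuda_fp16.h>

void example_gemm_fp16(int m, int n, int k) // illustrative, not from the commit
{
  half *A, *B, *out;
  cudaMalloc(&A, sizeof(half) * (size_t)m * k);
  cudaMalloc(&B, sizeof(half) * (size_t)k * n);
  cudaMalloc(&out, sizeof(half) * (size_t)m * n);
  // lda/ldb/ldc are assumptions about the layout gemm_device expects
  gemm_host<half>(m, n, k, A, B, out, k, k, m, /*bits=*/16);
  cudaDeviceSynchronize();
  cudaFree(A); cudaFree(B); cudaFree(out);
}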
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
-template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc);
-template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc);
+template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
+template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);

View File

@@ -190,7 +190,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
-template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
void pipeline_test(float *A, float *B, size_t n, size_t batch_size);

View File

@@ -21,9 +21,9 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
-{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc); }
+{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
-{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc); }
+{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
#define MAKE_FUNC32(fname, oname, gtype, gbits) \

View File

@@ -2352,8 +2352,8 @@ def test_normal_map_tree():
print(pivots)
-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_cutlass3_gemm(dtype):
for i in range(1):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')