First baseline kernel.

Tim Dettmers 2023-04-28 17:19:02 -07:00
parent 9cab14a3ff
commit c1bfb210c5
7 changed files with 118 additions and 32 deletions

View File

@@ -1429,7 +1429,7 @@ def cutlass3_gemm(
m = sB[1]
k = sB[0]
lda = B.stride()[(1 if transposed_B else 0)]
lda = B.stride()[0]
ldc = sB[1]
elif len(sB) == 3:
# special case
@@ -1446,7 +1446,7 @@ def cutlass3_gemm(
n = sA[2]
k = sB[0] * sB[1]
lda = m
lda = n
ldb = sA[2]
ldc = m
@@ -1454,7 +1454,7 @@ def cutlass3_gemm(
# B^T @ A^T = C^T
# [km, nk -> mn]
lda = ldb = ldc = 1
#lda = ldb = ldc = 1
#lda = 1
#print(m, n, k, lda, ldb, ldc)
is_on_gpu([B, A, out])
@@ -1466,7 +1466,7 @@ def cutlass3_gemm(
ldc = ct.c_int32(ldc)
alpha = ct.c_float(1.0)
beta = ct.c_float(0.0)
lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), ldb, get_ptr(B), lda, beta, get_ptr(out), ldc)
lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), lda, get_ptr(B), ldb, beta, get_ptr(out), ldc)
return out
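
The commented-out note above, B^T @ A^T = C^T with shapes [km, nk -> mn], is the standard layout trick behind this bookkeeping: the buffers of row-major A and B, reread in column-major order, already hold A^T and B^T, and the column-major product B^T @ A^T lands in memory exactly as row-major C. The PyTorch sketch below illustrates that identity only; it does not reproduce the kernel's exact parameter mapping.

# Hedged sketch of the B^T @ A^T = C^T layout trick noted above.
# Illustrative only; not the library's code path.
import torch

m, n, k = 2, 8, 16
A = torch.rand(m, k)                        # row-major, strides (k, 1)
B = torch.rand(k, n)                        # row-major, strides (n, 1)
C = A @ B                                   # row-major m x n

# Reinterpreting the same buffers as column-major matrices gives the
# transposes for free (no copies are made).
At = torch.as_strided(A, (k, m), (1, k))    # equals A.t()
Bt = torch.as_strided(B, (n, k), (1, n))    # equals B.t()

Ct = Bt @ At                                # (n x k) @ (k x m) -> C^T
assert torch.allclose(Ct, C.t())
# A column-major GEMM that writes C^T contiguously has therefore produced
# row-major C; the leading dimensions it needs are the row strides of the
# original tensors (here A.stride()[0] == k), in the spirit of the
# B.stride()[0] bookkeeping above.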

View File

@@ -2947,9 +2947,11 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// 9. write outputs to matmul output matrix
//}
#define ROWS 2
__global__ void gemm_device(int M, int N, int K,
float const* A,
float const* B,
float* B,
float * out, int lda, int ldb, int ldc,
float alpha, float beta)
{
@@ -2958,29 +2960,106 @@ __global__ void gemm_device(int M, int N, int K,
// 2. Dequantize B
// 3. Fetch data from A and multiply
typedef cub::BlockLoad<float, 256 , 1, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
__shared__ typename LoadA::TempStorage loada;
float dataA[1];
typedef cub::BlockLoad<float, 256 , 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
//__shared__ typename LoadA::TempStorage loada;
typedef cub::BlockLoad<float, 256 , 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
//__shared__ typename LoadB::TempStorage loadb;
typedef cub::BlockReduce<float, 256> BlockReduce;
// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce;
__shared__ union {
typename BlockReduce::TempStorage reduce;
typename LoadB::TempStorage loadb;
typename LoadA::TempStorage loada;
} temp_storage;
float dataA[4];
float local_B[4];
float local_accC[ROWS];
int valid_items = 0;
const int warp_id = threadIdx.x/32;
const int warp_lane = threadIdx.x % 32;
const int col_offset = blockIdx.x * 8;
__shared__ float[16*256] tileA;
__shared__ float tileA[ROWS*1024];
__shared__ float accumulatorC[ROWS*8];
//#pragma unroll 8
//for(int i = 0; i < 8; i++)
// tileA[threadIdx.x + (i*256)] = 0.0f;
//__syncthreads();
if(threadIdx.x < 64)
accumulatorC[threadIdx.x] = 0.0f;
__syncthreads();
for(int idxA = 0; idxA < M*K; idxA+= 256)
for(int inner_idx = 0; inner_idx < K; inner_idx+= 1024)
{
valid_items = M*K - idxA > 256 ? 256 : M*K - idxA;
valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx;
int baserow = 0;
for(int row = baserow; row < baserow+16 && row < M + ; row++)
for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
{
LoadA(loada).Load(&(A[(row*lda) + i]), dataA, valid_items, 0.0f);
tileA[row*256 + threadIdx.x] = dataA[0];
LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
#pragma unroll 4
for(int k = 0; k < 4; k++)
tileA[row*1024 + threadIdx.x + (k*blockDim.x)] = dataA[k];
__syncthreads();
}
baserow += 16;
baserow += ROWS;
// load 16 columns from B at a time. B is transposed, so its like loading rows
// each warp loads one row
// each thread loads 128 byte
// col: inner_idx + warp_lane
// row: ldb*(offset + warp_id)
for(int col = 0; col < 8 && (col_offset + col) < M; col++)
{
int colB = col_offset + col;
for(int k = 0; k < ROWS; k++)
local_accC[k] = 0.0f;
int base_idxB = ldb*colB;
valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx;
LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
__syncthreads();
for(int row = 0; row < ROWS && row < N; row++)
{
#pragma unroll 4
for(int k = 0; k < 4; k++)
{
int idxA = row*1024 + threadIdx.x + (blockDim.x*k);
local_accC[row] += tileA[idxA]*local_B[k];
}
local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
if(threadIdx.x == 0)
atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
}
}
}
for(int row = 0; row < ROWS && row < N; row++)
{
int out_idx = ldc*row + col_offset;
//if(threadIdx.x < 8)
// if(accumulatorC[row*8 + threadIdx.x] != 0.0)
// printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
{
//printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
}
}
}
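
In the added kernel, each thread block owns 8 output columns (col_offset = blockIdx.x * 8) and ROWS = 2 output rows. The K dimension is streamed through shared memory in 1024-element chunks (256 threads loading 4 floats each via cub::BlockLoad); every thread forms a partial dot product from its 4 elements, cub::BlockReduce sums the partials across the block, and thread 0 atomically adds the result into accumulatorC before the final write to out. The NumPy model below mirrors that decomposition for illustration only; it is not a line-by-line port of the kernel, and like the baseline it only produces the first ROWS output rows.

# Hedged sketch: NumPy model of the baseline kernel's tiling (8 output
# columns per block, ROWS rows, K streamed in 1024-wide chunks). The
# per-thread partials, BlockReduce, and atomicAdd collapse into a single
# dot product per (row, col) here.
import numpy as np

ROWS, COLS_PER_BLOCK, CHUNK = 2, 8, 1024      # CHUNK = 256 threads * 4 floats

def tiled_gemm(A, Bt):
    """Model of out = A @ Bt.T, with Bt holding B transposed as in the
    F.cutlass3_gemm(A, B.t()) call of test_cutlass3_gemm."""
    N, K = A.shape
    M, _ = Bt.shape
    out = np.zeros((N, M), dtype=A.dtype)
    num_blocks = (M + COLS_PER_BLOCK - 1) // COLS_PER_BLOCK     # (m + 7) / 8
    for block in range(num_blocks):                             # blockIdx.x
        col_offset = block * COLS_PER_BLOCK
        for inner_idx in range(0, K, CHUNK):                    # stream K
            valid = min(CHUNK, K - inner_idx)                   # valid_items
            tileA = A[:ROWS, inner_idx:inner_idx + valid]       # shared tileA
            for col in range(min(COLS_PER_BLOCK, M - col_offset)):
                local_B = Bt[col_offset + col, inner_idx:inner_idx + valid]
                for row in range(min(ROWS, N)):
                    out[row, col_offset + col] += tileA[row] @ local_B
    return out

A = np.random.rand(2, 4092).astype(np.float32)    # K = 4092 exercises the tail chunk
Bt = np.random.rand(40, 4092).astype(np.float32)  # 40 output columns -> 5 blocks
assert np.allclose(tiled_gemm(A, Bt), A @ Bt.T, rtol=1e-4, atol=1e-2)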

View File

@@ -140,7 +140,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
__global__ void gemm_device(int M, int N, int K,
float const* A,
float const* B,
float * B,
float * out, int lda, int ldb, int ldc,
float alpha, float beta);

View File

@@ -669,8 +669,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
int threads = 256;
int num_blocks = (n+(256*batch_size)+1)/(batch_size*256);
printf("%i %i\n", num_blocks, batch_size);
with_staging_unified<2><<<num_blocks, threads>>>(A, B, n, batch_size);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -680,15 +678,22 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
void gemm_host(int m, int n, int k,
float alpha,
float const* A, int lda,
float const* B, int ldb,
float * B, int ldb,
float beta,
float * C, int ldc)
{
dim3 dimBlock(256);
int num_blocks = (n+31)/32;
int num_blocks = (m+7)/8;
cout << num_blocks << endl;
cout << lda << endl;
cout << ldb << endl;
cout << ldc << endl;
cout << m << endl;
cout << n << endl;
cout << k << endl;
gemm_device
<<< num_blocks, dimBlock, 0, 0 >>>
(m, n, k,
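
The new grid size pairs with col_offset = blockIdx.x * 8 in the kernel: each block produces 8 output columns, so (m + 7) / 8 blocks cover all m columns, with only the last block partially filled. A quick Python sanity check of that ceiling division, purely illustrative:

# Hedged sketch: (m + 7) / 8 blocks of 8 columns cover any m.
def num_blocks(m, cols_per_block=8):
    return (m + cols_per_block - 1) // cols_per_block   # same as (m + 7) / 8 in C

for m in (1, 7, 8, 9, 4092, 4 * 4092):
    nb = num_blocks(m)
    assert (nb - 1) * 8 < m <= nb * 8                   # last block is the only partial one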

View File

@@ -193,7 +193,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
void gemm_host(int m, int n, int k,
float alpha,
float const* A, int ldA,
float const* B, int ldB,
float * B, int ldB,
float beta,
float * C, int ldC);

View File

@@ -24,7 +24,7 @@ void
cppgemm(int m, int n, int k,
float alpha,
float const* A, int ldA,
float const* B, int ldB,
float * B, int ldB,
float beta,
float * C, int ldC)
{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
@@ -320,7 +320,7 @@ extern "C"
void ccutlass_gemm(int m, int n, int k,
float alpha,
float const* A, int ldA,
float const* B, int ldB,
float * B, int ldB,
float beta,
float * C, int ldC)
{ cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
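
These extern "C" wrappers are the symbols the ctypes call in functional.py resolves to. The sketch below shows one way to drive the exported ccutlass_gemm directly from Python, mirroring the lib.ccutlass_gemm(...) call above; the library path and the get_ptr helper are illustrative stand-ins, not the loader bitsandbytes actually uses.

# Hedged sketch: calling the exported ccutlass_gemm via ctypes.
# The .so path and get_ptr helper are illustrative assumptions.
import ctypes as ct

lib = ct.cdll.LoadLibrary("./libbitsandbytes.so")       # illustrative path

def get_ptr(t):
    # raw device pointer of a CUDA tensor, as used by the wrapper
    return ct.c_void_p(t.data_ptr())

def cutlass3_gemm_raw(A, B, out, m, n, k, lda, ldb, ldc):
    lib.ccutlass_gemm(
        ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
        ct.c_float(1.0),                                # alpha
        get_ptr(A), ct.c_int32(lda),
        get_ptr(B), ct.c_int32(ldb),
        ct.c_float(0.0),                                # beta
        get_ptr(out), ct.c_int32(ldc),
    )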

View File

@@ -2353,17 +2353,19 @@ def test_normal_map_tree():
def test_cutlass3_gemm():
#A = torch.rand(2, 2).cuda()
#B = torch.rand(2, 2).cuda()
A = torch.arange(4).reshape(2, 2).float().cuda().contiguous()
B = torch.ones(2, 2).float().cuda()
A = torch.rand(2, 4092).cuda()
B = torch.rand(4*4092, 4092).cuda()
print('')
print(A)
print(B)
#print('')
#print(A)
#print(B.t())
C1 = torch.matmul(A, B)
C2 = F.cutlass3_gemm(A, B)
C1 = torch.matmul(A, B.t())
C2 = F.cutlass3_gemm(A, B.t())
#print(C1)
#print(C2)
torch.testing.assert_close(C1, C2)
def test_pipeline_func():
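
With K = 4092 and float32 accumulation, the baseline kernel's summation order (per-thread partials, block reduction, atomicAdd) differs from cuBLAS, so C1 and C2 in test_cutlass3_gemm can only agree approximately. If the default tolerances of torch.testing.assert_close prove too tight, the check can be loosened explicitly; the variant below is a hedged sketch and the tolerance values are illustrative, not the library's choice.

# Hedged variant of the final check in test_cutlass3_gemm above.
err = (C1 - C2).abs().max() / C1.abs().max()    # max relative error of the kernel
print(f"max relative error: {err.item():.3e}")
torch.testing.assert_close(C1, C2, rtol=1e-4, atol=1e-2)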