First baseline kernel.
This commit is contained in:
parent
9cab14a3ff
commit
c1bfb210c5
|
@ -1429,7 +1429,7 @@ def cutlass3_gemm(
|
|||
|
||||
m = sB[1]
|
||||
k = sB[0]
|
||||
lda = B.stride()[(1 if transposed_B else 0)]
|
||||
lda = B.stride()[0]
|
||||
ldc = sB[1]
|
||||
elif len(sB) == 3:
|
||||
# special case
|
||||
|
@ -1446,7 +1446,7 @@ def cutlass3_gemm(
|
|||
n = sA[2]
|
||||
k = sB[0] * sB[1]
|
||||
|
||||
lda = m
|
||||
lda = n
|
||||
ldb = sA[2]
|
||||
ldc = m
|
||||
|
||||
|
@ -1454,7 +1454,7 @@ def cutlass3_gemm(
|
|||
|
||||
# B^T @ A^T = C^T
|
||||
# [km, nk -> mn]
|
||||
lda = ldb = ldc = 1
|
||||
#lda = ldb = ldc = 1
|
||||
#lda = 1
|
||||
#print(m, n, k, lda, ldb, ldc)
|
||||
is_on_gpu([B, A, out])
|
||||
|
@ -1466,7 +1466,7 @@ def cutlass3_gemm(
|
|||
ldc = ct.c_int32(ldc)
|
||||
alpha = ct.c_float(1.0)
|
||||
beta = ct.c_float(0.0)
|
||||
lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), ldb, get_ptr(B), lda, beta, get_ptr(out), ldc)
|
||||
lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), lda, get_ptr(B), ldb, beta, get_ptr(out), ldc)
|
||||
|
||||
return out
|
||||
|
||||
|
|
101
csrc/kernels.cu
101
csrc/kernels.cu
|
@ -2947,9 +2947,11 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
|
|||
//// 9. write outputs to matmul output matrix
|
||||
//}
|
||||
|
||||
|
||||
#define ROWS 2
|
||||
__global__ void gemm_device(int M, int N, int K,
|
||||
float const* A,
|
||||
float const* B,
|
||||
float* B,
|
||||
float * out, int lda, int ldb, int ldc,
|
||||
float alpha, float beta)
|
||||
{
|
||||
|
@ -2958,29 +2960,106 @@ __global__ void gemm_device(int M, int N, int K,
|
|||
// 2. Dequantize B
|
||||
// 3. Fetch data from A and multiply
|
||||
|
||||
typedef cub::BlockLoad<float, 256 , 1, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
|
||||
__shared__ typename LoadA::TempStorage loada;
|
||||
float dataA[1];
|
||||
typedef cub::BlockLoad<float, 256 , 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
|
||||
//__shared__ typename LoadA::TempStorage loada;
|
||||
typedef cub::BlockLoad<float, 256 , 4, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
|
||||
//__shared__ typename LoadB::TempStorage loadb;
|
||||
typedef cub::BlockReduce<float, 256> BlockReduce;
|
||||
// Allocate shared memory for BlockReduce
|
||||
//__shared__ typename BlockReduce::TempStorage reduce;
|
||||
|
||||
__shared__ union {
|
||||
typename BlockReduce::TempStorage reduce;
|
||||
typename LoadB::TempStorage loadb;
|
||||
typename LoadA::TempStorage loada;
|
||||
} temp_storage;
|
||||
|
||||
|
||||
float dataA[4];
|
||||
float local_B[4];
|
||||
float local_accC[ROWS];
|
||||
int valid_items = 0;
|
||||
const int warp_id = threadIdx.x/32;
|
||||
const int warp_lane = threadIdx.x % 32;
|
||||
const int col_offset = blockIdx.x * 8;
|
||||
|
||||
__shared__ float[16*256] tileA;
|
||||
__shared__ float tileA[ROWS*1024];
|
||||
__shared__ float accumulatorC[ROWS*8];
|
||||
|
||||
//#pragma unroll 8
|
||||
//for(int i = 0; i < 8; i++)
|
||||
// tileA[threadIdx.x + (i*256)] = 0.0f;
|
||||
//__syncthreads();
|
||||
if(threadIdx.x < 64)
|
||||
accumulatorC[threadIdx.x] = 0.0f;
|
||||
__syncthreads();
|
||||
|
||||
|
||||
for(int idxA = 0; idxA < M*K; idxA+= 256)
|
||||
for(int inner_idx = 0; inner_idx < K; inner_idx+= 1024)
|
||||
{
|
||||
valid_items = M*K - idxA > 256 ? 256 : M*K - idxA;
|
||||
valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx;
|
||||
int baserow = 0;
|
||||
for(int row = baserow; row < baserow+16 && row < M + ; row++)
|
||||
for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
|
||||
{
|
||||
LoadA(loada).Load(&(A[(row*lda) + i]), dataA, valid_items, 0.0f);
|
||||
tileA[row*256 + threadIdx.x] = dataA[0];
|
||||
LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
|
||||
|
||||
#pragma unroll 4
|
||||
for(int k = 0; k < 4; k++)
|
||||
tileA[row*1024 + threadIdx.x + (k*blockDim.x)] = dataA[k];
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
baserow += 16;
|
||||
baserow += ROWS;
|
||||
|
||||
// load 16 columns from B at a time. B is transposed, so it's like loading rows
|
||||
// each warp loads one row
|
||||
// each thread loads 128 bytes
|
||||
|
||||
// col: inner_idx + warp_lane
|
||||
// row: ldb*(offset + warp_id)
|
||||
for(int col = 0; col < 8 && (col_offset + col) < M; col++)
|
||||
{
|
||||
int colB = col_offset + col;
|
||||
|
||||
for(int k = 0; k < ROWS; k++)
|
||||
local_accC[k] = 0.0f;
|
||||
|
||||
int base_idxB = ldb*colB;
|
||||
valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx;
|
||||
LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
|
||||
__syncthreads();
|
||||
|
||||
for(int row = 0; row < ROWS && row < N; row++)
|
||||
{
|
||||
#pragma unroll 4
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int idxA = row*1024 + threadIdx.x + (blockDim.x*k);
|
||||
local_accC[row] += tileA[idxA]*local_B[k];
|
||||
}
|
||||
|
||||
local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
|
||||
if(threadIdx.x == 0)
|
||||
atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int row = 0; row < ROWS && row < N; row++)
|
||||
{
|
||||
int out_idx = ldc*row + col_offset;
|
||||
|
||||
//if(threadIdx.x < 8)
|
||||
// if(accumulatorC[row*8 + threadIdx.x] != 0.0)
|
||||
// printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
|
||||
|
||||
if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
|
||||
{
|
||||
//printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
|
||||
out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -140,7 +140,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
|
|||
|
||||
__global__ void gemm_device(int M, int N, int K,
|
||||
float const* A,
|
||||
float const* B,
|
||||
float * B,
|
||||
float * out, int lda, int ldb, int ldc,
|
||||
float alpha, float beta);
|
||||
|
||||
|
|
13
csrc/ops.cu
13
csrc/ops.cu
|
@ -669,8 +669,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
|
|||
int threads = 256;
|
||||
int num_blocks = (n+(256*batch_size)+1)/(batch_size*256);
|
||||
|
||||
printf("%i %i\n", num_blocks, batch_size);
|
||||
|
||||
with_staging_unified<2><<<num_blocks, threads>>>(A, B, n, batch_size);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
@ -680,15 +678,22 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
|
|||
void gemm_host(int m, int n, int k,
|
||||
float alpha,
|
||||
float const* A, int lda,
|
||||
float const* B, int ldb,
|
||||
float * B, int ldb,
|
||||
float beta,
|
||||
float * C, int ldc)
|
||||
{
|
||||
|
||||
dim3 dimBlock(256);
|
||||
int num_blocks = (n+31)/32;
|
||||
int num_blocks = (m+7)/8;
|
||||
|
||||
cout << num_blocks << endl;
|
||||
cout << lda << endl;
|
||||
cout << ldb << endl;
|
||||
cout << ldc << endl;
|
||||
|
||||
cout << m << endl;
|
||||
cout << n << endl;
|
||||
cout << k << endl;
|
||||
gemm_device
|
||||
<<< num_blocks, dimBlock, 0, 0 >>>
|
||||
(m, n, k,
|
||||
|
|
|
@ -193,7 +193,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
|
|||
void gemm_host(int m, int n, int k,
|
||||
float alpha,
|
||||
float const* A, int ldA,
|
||||
float const* B, int ldB,
|
||||
float * B, int ldB,
|
||||
float beta,
|
||||
float * C, int ldC);
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ void
|
|||
cppgemm(int m, int n, int k,
|
||||
float alpha,
|
||||
float const* A, int ldA,
|
||||
float const* B, int ldB,
|
||||
float * B, int ldB,
|
||||
float beta,
|
||||
float * C, int ldC)
|
||||
{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
|
||||
|
@ -320,7 +320,7 @@ extern "C"
|
|||
void ccutlass_gemm(int m, int n, int k,
|
||||
float alpha,
|
||||
float const* A, int ldA,
|
||||
float const* B, int ldB,
|
||||
float * B, int ldB,
|
||||
float beta,
|
||||
float * C, int ldC)
|
||||
{ cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
|
||||
|
|
|
@ -2353,17 +2353,19 @@ def test_normal_map_tree():
|
|||
|
||||
|
||||
def test_cutlass3_gemm():
|
||||
#A = torch.rand(2, 2).cuda()
|
||||
#B = torch.rand(2, 2).cuda()
|
||||
A = torch.arange(4).reshape(2, 2).float().cuda().contiguous()
|
||||
B = torch.ones(2, 2).float().cuda()
|
||||
A = torch.rand(2, 4092).cuda()
|
||||
B = torch.rand(4*4092, 4092).cuda()
|
||||
|
||||
print('')
|
||||
print(A)
|
||||
print(B)
|
||||
#print('')
|
||||
#print(A)
|
||||
#print(B.t())
|
||||
|
||||
C1 = torch.matmul(A, B)
|
||||
C2 = F.cutlass3_gemm(A, B)
|
||||
C1 = torch.matmul(A, B.t())
|
||||
C2 = F.cutlass3_gemm(A, B.t())
|
||||
#print(C1)
|
||||
#print(C2)
|
||||
|
||||
torch.testing.assert_close(C1, C2)
|
||||
|
||||
|
||||
def test_pipeline_func():
|
||||
|
|
Loading…
Reference in New Issue
Block a user