Slow tensor core solution.

Tim Dettmers 2023-04-30 17:43:02 -07:00
parent 21723f796a
commit ad07d254fb
4 changed files with 160 additions and 48 deletions


@@ -14,6 +14,7 @@
#include <math_constants.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda/pipeline>
@@ -23,6 +24,8 @@
#define NUM 4
#define NUM_BLOCK 4096
using namespace nvcuda;
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
__device__ float atomicMax(float* address, float val) {
int* address_as_i = reinterpret_cast<int*>(address);
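// (sketch, not from this commit: the hunk truncates the body here; per the linked
// Stack Overflow answer it continues with the usual CAS loop)
//     int old = *address_as_i, assumed;
//     do {
//         assumed = old;
//         old = atomicCAS(address_as_i, assumed,
//                         __float_as_int(fmaxf(val, __int_as_float(assumed))));
//     } while (assumed != old);
//     return __int_as_float(old);
// }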
@@ -3041,62 +3044,164 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
typedef cub::BlockReduce<T, THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage reduce;
int col_offset = blockIdx.x *8;
typedef cub::WarpReduce<half> WarpReduce;
// Allocate WarpReduce shared memory for one warp
//__shared__ typename WarpReduce::TempStorage temp_storage;
T local_A[128/BITS];
T local_B[128/BITS];
//typedef cub::BlockReduce<T, THREADS> BlockReduce;
//// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce;
int col_offset = blockIdx.x *8;
const int warp_id = threadIdx.x / 32;
const int warp_lane = threadIdx.x % 32;
T local_A[64/BITS];
T local_B[64/BITS];
T local_C[8];
__shared__ T smem_C[8];
__shared__ T smem_A[4*32*16];
__shared__ T smem_B[4*16*8];
__shared__ T smem_C[4*32*8];
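// staging tiles for 4 warps: A 4x(32x16), B 4x(16x8), C 4x(32x8)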
if(threadIdx.x < 8)
smem_C[threadIdx.x] = T(0);
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
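// per warp: C(32x8) += A(32x16, row-major) * B(16x8, col-major) via one wmma tile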
for(int i = threadIdx.x; i < 32*16*4; i+=blockDim.x)
smem_A[i] = T(0);
for(int i = threadIdx.x; i < 32*8*4; i+=blockDim.x)
smem_B[i] = T(0);
for(int i = threadIdx.x; i < 32*8*THREADS/32; i+=blockDim.x)
smem_C[i] = T(0);
__syncthreads();
#pragma unroll 8
for(int k = 0; k < 8; k++)
local_C[k] = T(0);
for(int idx = threadIdx.x*128/BITS; idx < K; idx+=blockDim.x*128/BITS)
int block_idx = 0;
//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
for(int base_idx = 0; base_idx < K; base_idx+=64)
{
vector_load<T, int4, 128/BITS>(local_A, A, idx, idx, K);
for(int col = 0; col < 8; col++)
int tidx = threadIdx.x*4;
if(base_idx % (4*blockDim.x) == 0)
{
int offset_B = (col_offset+col)*ldb;
vector_load<T, int4, 128/BITS>(local_B, B, offset_B+idx, idx, K);
#pragma unroll 128/BITS
for(int k = 0; k < 128/BITS; k++)
local_C[col] += local_A[k]*local_B[k];
vector_load<T, int2, 64/BITS>(local_A, A, base_idx+tidx, base_idx+tidx, K); // 54 mu
block_idx = 0;
}
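// each thread now caches 64/BITS values of A in registers (4 halves for BITS=16),
// refreshed once every 4*blockDim.x elements of K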
}
#pragma unroll 8
for(int k = 0; k < 8; k++)
{
local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
for(int k = 0; k < 4; k++)
{
if((threadIdx.x >= block_idx*16) && (threadIdx.x < (block_idx+1)*16))
smem_A[(threadIdx.x % 16) + (32*16*k)] = local_A[k]; // 54 mu
}
block_idx += 1;
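// the 16 threads whose registers covered this 64-wide chunk (16 threads x 4 values)
// just scattered it into smem_A; block_idx now points at the next group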
// 4 warps; each loads one 64-value column -> 4 columns at a time
// we need 8 columns, so 2 rounds of loads and smem stores
// we need a half-warp to load one column at a time
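// with THREADS=128 there are 4 warps: j=0 fills columns 0..3, j=1 fills columns 4..7 (col = warp_id + j*4)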
for(int j = 0; j < 2; j++)
{
int col = warp_id + (j*4);
int offset_B = (col_offset+col)*ldb;
vector_load<T, int2, 64/BITS>(local_B, B, offset_B+base_idx+warp_lane*4, base_idx+warp_lane*4, K); // 171 mu
//#pragma unroll 4
//for(int k = 0; k < 4; k++)
// if((float)local_B[k] != 0.0)
// printf("%i %i %i %i %f\n", j, warp_id, warp_lane, k, (float)local_B[k]);
// load and store patterns are different:
// we want to load 64 consecutive values with one warp,
// but we need to store them across 4 fragments since
// the max column width is 16.
// every 16 values start a new tile for each warp
//int tile_idx = warp_lane/16;
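// i.e. lane w writes local_B[k] into row (w % 16) of column col within B tile k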
#pragma unroll 4
for(int k = 0; k < 4; k++)
smem_B[(warp_lane % 16) + (col*16) + (k*16*8)] = local_B[k]; // 171 mu
}
__syncthreads();
//if(threadIdx.x == 0)
// for(int w = 0; w < 4; w++)
// for(int trow = 0; trow < 32; trow++)
// for(int tcol = 0; tcol < 16; tcol++)
// if((float)smem_A[trow + tcol*32 + (w*32*16)] != 0.0)
// printf("A %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
//if(threadIdx.x == 0)
// for(int w = 0; w < 4; w++)
// for(int trow = 0; trow < 16; trow++)
// for(int tcol = 0; tcol < 8; tcol++)
// if((float)smem_B[trow + tcol*16 + (w*16*8)] != 0.0)
// printf("B %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
//__syncthreads();
wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*32*16]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*16*8]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
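// c_frag accumulates this warp's 32x8 partial product across the base_idx chunks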
}
if(threadIdx.x == 0)
{
#pragma unroll 8
for(int k = 0; k < 8; k++)
smem_C[k] = local_C[k];
}
else if(threadIdx.x >= 32)
// early return for unused warps
return;
// 129 mu
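// each warp that reaches this point dumps its 32x8 accumulator tile into its own smem_C slice (row-major, ld 8)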
wmma::store_matrix_sync(&(smem_C[warp_id*32*8]), c_frag, 8, wmma::mem_row_major);
__syncthreads();
__syncwarp();
//if(threadIdx.x >= 16){ return; }
//printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
//if(threadIdx.x < 32)
if(warp_lane < 8 && warp_id > 0)
//local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
__syncthreads();
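// after the atomicAdds, smem_C[0..7] is meant to hold the 8 outputs summed across the warp tiles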
//local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
//if(threadIdx.x == 0)
// for(int row = 0; row < 32; row++)
// {
// printf("row %i ", row);
// for(int id = 0; id < 4; id++)
// {
// printf(" id %i: ", id);
// for(int k = 0; k < 8; k++)
// printf("%f ", (float)smem_C[k + (row*8) + (id*32*8)]);
// printf("\n");
// }
// }
//__syncthreads();
//if((float)local_C[0] !=0.0f)
// printf("%i %i %f\n", warp_lane, warp_id, (float)local_C[0]);
//local_C[0] = WarpReduce(temp_storage).Sum(local_C[0]);
//__syncwarp();
////for(int i = threadIdx.x; i < 32*8; i+=blockDim.x)
////{
// if((float)local_C[0] !=0.0f)
// printf("%i %f\n", 0, (float)local_C[0]);
//}
//if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
//out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
}
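// Hedged sketch, not part of the commit: the m32n8k16 wmma pattern above in
// isolation. One warp multiplies a 32x16 half tile (row-major, ld 16) by a
// 16x8 half tile (col-major, ld 16) into a 32x8 tile; assumes sm_70+ and
// device buffers of exactly those shapes. The kernel name is made up.
__global__ void wmma_m32n8k16_demo(const half *a, const half *b, half *c)
{
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
wmma::load_matrix_sync(a_frag, a, 16);
wmma::load_matrix_sync(b_frag, b, 16);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::store_matrix_sync(c, c_frag, 8, wmma::mem_row_major);
}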
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
@@ -3378,12 +3483,16 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
// half alpha, half beta);
// these are not used and make no sense, but the compiler needs them
template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
// these are not used and make no sense, but the compiler needs them
template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);


@@ -678,7 +678,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
dim3 dimBlock(128);
int num_blocks = (m+7)/8;
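// ceil(m/8): each block computes 8 consecutive output values (col_offset = blockIdx.x*8 in the kernel)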
cout << num_blocks << endl;
@@ -689,16 +688,17 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
cout << m << endl;
cout << n << endl;
cout << k << endl;
if(bits == 32)
gemm_device<T, 32, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
else if(bits == 16)
gemm_device<T, 16, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//if(bits == 32)
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16)
gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
}
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
dim3 dimBlock(128);
int num_blocks = (m+7)/8;
cout << num_blocks << endl;
@@ -709,7 +709,8 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
cout << m << endl;
cout << n << endl;
cout << k << endl;
kgemm_4bit_inference<T, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
kgemm_4bit_inference<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
//==============================================================
@@ -717,7 +718,7 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
//==============================================================
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);


@@ -20,8 +20,8 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
@@ -316,8 +316,8 @@ extern "C"
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }
void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
//void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }


@@ -2358,6 +2358,8 @@ def test_cutlass3_gemm(dtype):
for i in range(1):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
A = torch.rand(1, 4096, dtype=dtype, device='cuda')
B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')