8x32 240 6 warps.

This commit is contained in:
Tim Dettmers 2023-05-01 16:38:09 -07:00
parent 3d4a2eadd3
commit 7bfa09d0fc
2 changed files with 31 additions and 25 deletions

View File

@ -3041,7 +3041,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
} }
} }
#define WARPS 4 #define WARPS 6
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{ {
@ -3052,26 +3052,26 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//typedef cub::BlockReduce<T, THREADS> BlockReduce; //typedef cub::BlockReduce<T, THREADS> BlockReduce;
//// Allocate shared memory for BlockReduce //// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce; //__shared__ typename BlockReduce::TempStorage reduce;
int col_offset = blockIdx.x *16; int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32; const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16; const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16; const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2; const int batch_size_warps = (WARPS-1)*2;
T local_A[1]; T local_A[1];
T local_B[16]; T local_B[32];
const int a_tile_offset = (16*16 + 16); const int a_tile_offset = (8*16 + 16);
const int b_tile_offset = (16*16 + 16); const int b_tile_offset = (16*32 + 16);
const int c_tile_offset = 16*16 + 24; const int c_tile_offset = 8*32 + 24;
__shared__ T smem_A[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))]; __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
__shared__ T smem_C[16*16]; __shared__ T smem_C[8*32];
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag; wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag; wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag; wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f); wmma::fill_fragment(c_frag, 0.0f);
@ -3082,7 +3082,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x) //for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
// smem_B[i] = T(0); // smem_B[i] = T(0);
for(int i = threadIdx.x; i < 16*16; i+=blockDim.x) for(int i = threadIdx.x; i < 8*32; i+=blockDim.x)
smem_C[i] = T(0); smem_C[i] = T(0);
__syncthreads(); __syncthreads();
@ -3099,14 +3099,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
{ {
local_A[0] = A[idx]; local_A[0] = A[idx];
#pragma unroll 16 #pragma unroll 32
for(int col = 0; col < 16; col++) for(int col = 0; col < 32; col++)
local_B[col] = B[(col_offset+col)*ldb+idx]; local_B[col] = B[(col_offset+col)*ldb+idx];
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0]; smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0];
#pragma unroll 16 #pragma unroll 32
for(int col = 0; col < 16; col++) for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col]; smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col];
} }
ticktock = ticktock == 0 ? 1 : 0; ticktock = ticktock == 0 ? 1 : 0;
@ -3120,14 +3120,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
{ {
local_A[0] = A[idx]; local_A[0] = A[idx];
#pragma unroll 16 #pragma unroll 32
for(int col = 0; col < 16; col++) for(int col = 0; col < 32; col++)
local_B[col] = B[(col_offset+col)*ldb+idx]; local_B[col] = B[(col_offset+col)*ldb+idx];
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
#pragma unroll 16 #pragma unroll 32
for(int col = 0; col < 16; col++) for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
} }
ticktock = ticktock == 0 ? 1 : 0; ticktock = ticktock == 0 ? 1 : 0;
@ -3143,7 +3143,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
// 129 mu // 129 mu
if(warp_id == (WARPS-1)) if(warp_id == (WARPS-1))
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 16, wmma::mem_row_major); wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
__syncthreads(); __syncthreads();
//if(threadIdx.x >= 16){ return; } //if(threadIdx.x >= 16){ return; }
@ -3185,7 +3185,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//if(threadIdx.x < 8 && col_offset + threadIdx.x < M) //if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
//out[col_offset + threadIdx.x ] = smem_C[threadIdx.x]; //out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
if(threadIdx.x < 16 && col_offset + threadIdx.x < M) if(threadIdx.x < 32 && col_offset + threadIdx.x < M)
out[col_offset + threadIdx.x] = smem_C[threadIdx.x]; out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
} }
@ -3470,18 +3470,22 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
// these are not used and make no sense, but the compiler needs them // these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); //template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); //template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
// these are not used and make no sense, but the compiler needs them // these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); //template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); //template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

View File

@ -678,7 +678,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{ {
int num_blocks = (m+15)/16; int num_blocks = (m+31)/32;
cout << num_blocks << endl; cout << num_blocks << endl;
cout << lda << endl; cout << lda << endl;
@ -693,7 +693,9 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16) if(bits == 16)
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); gemm_device<T, 16, 192><<< num_blocks, 192, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 16, 64><<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<T, 16, 64><<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
} }