8x32 240 6 warps.
This commit is contained in:
parent
3d4a2eadd3
commit
7bfa09d0fc
|
@ -3041,7 +3041,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define WARPS 4
|
#define WARPS 6
|
||||||
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
|
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
|
||||||
{
|
{
|
||||||
|
|
||||||
|
@ -3052,26 +3052,26 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
||||||
//typedef cub::BlockReduce<T, THREADS> BlockReduce;
|
//typedef cub::BlockReduce<T, THREADS> BlockReduce;
|
||||||
//// Allocate shared memory for BlockReduce
|
//// Allocate shared memory for BlockReduce
|
||||||
//__shared__ typename BlockReduce::TempStorage reduce;
|
//__shared__ typename BlockReduce::TempStorage reduce;
|
||||||
int col_offset = blockIdx.x *16;
|
int col_offset = blockIdx.x *32;
|
||||||
const int warp_id = threadIdx.x / 32;
|
const int warp_id = threadIdx.x / 32;
|
||||||
const int half_warp_id = threadIdx.x / 16;
|
const int half_warp_id = threadIdx.x / 16;
|
||||||
const int half_warp_lane = threadIdx.x % 16;
|
const int half_warp_lane = threadIdx.x % 16;
|
||||||
const int batch_size_warps = (WARPS-1)*2;
|
const int batch_size_warps = (WARPS-1)*2;
|
||||||
|
|
||||||
T local_A[1];
|
T local_A[1];
|
||||||
T local_B[16];
|
T local_B[32];
|
||||||
|
|
||||||
const int a_tile_offset = (16*16 + 16);
|
const int a_tile_offset = (8*16 + 16);
|
||||||
const int b_tile_offset = (16*16 + 16);
|
const int b_tile_offset = (16*32 + 16);
|
||||||
const int c_tile_offset = 16*16 + 24;
|
const int c_tile_offset = 8*32 + 24;
|
||||||
|
|
||||||
__shared__ T smem_A[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))];
|
__shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))];
|
||||||
__shared__ T smem_B[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))];
|
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
|
||||||
__shared__ T smem_C[16*16];
|
__shared__ T smem_C[8*32];
|
||||||
|
|
||||||
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
|
wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
|
||||||
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
|
wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
|
||||||
wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
|
wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
|
||||||
|
|
||||||
wmma::fill_fragment(c_frag, 0.0f);
|
wmma::fill_fragment(c_frag, 0.0f);
|
||||||
|
|
||||||
|
@ -3082,7 +3082,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
||||||
//for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
|
//for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
|
||||||
// smem_B[i] = T(0);
|
// smem_B[i] = T(0);
|
||||||
|
|
||||||
for(int i = threadIdx.x; i < 16*16; i+=blockDim.x)
|
for(int i = threadIdx.x; i < 8*32; i+=blockDim.x)
|
||||||
smem_C[i] = T(0);
|
smem_C[i] = T(0);
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
@ -3099,14 +3099,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
||||||
{
|
{
|
||||||
local_A[0] = A[idx];
|
local_A[0] = A[idx];
|
||||||
|
|
||||||
#pragma unroll 16
|
#pragma unroll 32
|
||||||
for(int col = 0; col < 16; col++)
|
for(int col = 0; col < 32; col++)
|
||||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||||
|
|
||||||
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0];
|
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0];
|
||||||
|
|
||||||
#pragma unroll 16
|
#pragma unroll 32
|
||||||
for(int col = 0; col < 16; col++)
|
for(int col = 0; col < 32; col++)
|
||||||
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col];
|
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col];
|
||||||
}
|
}
|
||||||
ticktock = ticktock == 0 ? 1 : 0;
|
ticktock = ticktock == 0 ? 1 : 0;
|
||||||
|
@ -3120,14 +3120,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
||||||
{
|
{
|
||||||
local_A[0] = A[idx];
|
local_A[0] = A[idx];
|
||||||
|
|
||||||
#pragma unroll 16
|
#pragma unroll 32
|
||||||
for(int col = 0; col < 16; col++)
|
for(int col = 0; col < 32; col++)
|
||||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||||
|
|
||||||
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
|
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
|
||||||
|
|
||||||
#pragma unroll 16
|
#pragma unroll 32
|
||||||
for(int col = 0; col < 16; col++)
|
for(int col = 0; col < 32; col++)
|
||||||
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
|
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
|
||||||
}
|
}
|
||||||
ticktock = ticktock == 0 ? 1 : 0;
|
ticktock = ticktock == 0 ? 1 : 0;
|
||||||
|
@ -3143,7 +3143,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
||||||
|
|
||||||
// 129 mu
|
// 129 mu
|
||||||
if(warp_id == (WARPS-1))
|
if(warp_id == (WARPS-1))
|
||||||
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 16, wmma::mem_row_major);
|
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
//if(threadIdx.x >= 16){ return; }
|
//if(threadIdx.x >= 16){ return; }
|
||||||
|
@ -3185,7 +3185,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
||||||
|
|
||||||
//if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
|
//if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
|
||||||
//out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
|
//out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
|
||||||
if(threadIdx.x < 16 && col_offset + threadIdx.x < M)
|
if(threadIdx.x < 32 && col_offset + threadIdx.x < M)
|
||||||
out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
|
out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3470,18 +3470,22 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
|
||||||
// these are not used and make no sense, but the compiler needs them
|
// these are not used and make no sense, but the compiler needs them
|
||||||
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
||||||
template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
|
template __global__ void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
||||||
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
|
template __global__ void gemm_device<half, 32, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
// these are not used and make no sense, but the compiler needs them
|
// these are not used and make no sense, but the compiler needs them
|
||||||
|
|
||||||
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
||||||
template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
|
template __global__ void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
||||||
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
|
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||||
|
|
||||||
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||||
|
|
||||||
|
|
|
@ -678,7 +678,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
|
||||||
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
|
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
|
||||||
{
|
{
|
||||||
|
|
||||||
int num_blocks = (m+15)/16;
|
int num_blocks = (m+31)/32;
|
||||||
|
|
||||||
cout << num_blocks << endl;
|
cout << num_blocks << endl;
|
||||||
cout << lda << endl;
|
cout << lda << endl;
|
||||||
|
@ -693,7 +693,9 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
|
||||||
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||||
if(bits == 16)
|
if(bits == 16)
|
||||||
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||||
gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
gemm_device<T, 16, 192><<< num_blocks, 192, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||||
|
//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||||
|
//gemm_device<T, 16, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||||
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||||
//gemm_device<T, 16, 64><<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
//gemm_device<T, 16, 64><<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user