16x16 240.

This commit is contained in:
Tim Dettmers 2023-05-01 16:23:45 -07:00
parent 7cc8ff4727
commit 3d4a2eadd3
2 changed files with 27 additions and 27 deletions

View File

@ -3052,37 +3052,37 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//typedef cub::BlockReduce<T, THREADS> BlockReduce;
//// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce;
int col_offset = blockIdx.x *8;
int col_offset = blockIdx.x *16;
const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;
T local_A[1];
T local_B[8];
T local_B[16];
const int a_tile_offset = (32*16 + 16);
const int b_tile_offset = (16*8 + 16);
const int c_tile_offset = 32*8 + 24;
const int a_tile_offset = (16*16 + 16);
const int b_tile_offset = (16*16 + 16);
const int c_tile_offset = 16*16 + 24;
__shared__ T smem_A[2*batch_size_warps*32*16 + (2*16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*8 + (2*16*(batch_size_warps-1))];
__shared__ T smem_C[32*8];
__shared__ T smem_A[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*16 + (2*16*(batch_size_warps-1))];
__shared__ T smem_C[16*16];
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
for(int i = threadIdx.x; i < 32*16*WARPS; i+=blockDim.x)
smem_A[i] = T(0);
//for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
// smem_A[i] = T(0);
for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
smem_B[i] = T(0);
//for(int i = threadIdx.x; i < 16*16*WARPS; i+=blockDim.x)
// smem_B[i] = T(0);
for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
for(int i = threadIdx.x; i < 16*16; i+=blockDim.x)
smem_C[i] = T(0);
__syncthreads();
@ -3099,14 +3099,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
{
local_A[0] = A[idx];
#pragma unroll 8
for(int col = 0; col < 8; col++)
#pragma unroll 16
for(int col = 0; col < 16; col++)
local_B[col] = B[(col_offset+col)*ldb+idx];
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = local_A[0];
#pragma unroll 8
for(int col = 0; col < 8; col++)
#pragma unroll 16
for(int col = 0; col < 16; col++)
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col];
}
ticktock = ticktock == 0 ? 1 : 0;
@ -3120,14 +3120,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
{
local_A[0] = A[idx];
#pragma unroll 8
for(int col = 0; col < 8; col++)
#pragma unroll 16
for(int col = 0; col < 16; col++)
local_B[col] = B[(col_offset+col)*ldb+idx];
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
#pragma unroll 8
for(int col = 0; col < 8; col++)
#pragma unroll 16
for(int col = 0; col < 16; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
ticktock = ticktock == 0 ? 1 : 0;
@ -3143,7 +3143,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
// 129 mu
if(warp_id == (WARPS-1))
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major);
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 16, wmma::mem_row_major);
__syncthreads();
//if(threadIdx.x >= 16){ return; }
@ -3185,7 +3185,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
//out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
if(threadIdx.x < 16 && col_offset + threadIdx.x < M)
out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
}

View File

@ -678,7 +678,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
int num_blocks = (m+7)/8;
int num_blocks = (m+15)/16;
cout << num_blocks << endl;
cout << lda << endl;