64 threads, high smem, 434.

This commit is contained in:
Tim Dettmers 2023-04-30 18:55:12 -07:00
parent e01d4e033d
commit 30d03e0254
2 changed files with 26 additions and 25 deletions

View File

@ -3041,7 +3041,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
} }
} }
#define WARPS 1 #define WARPS 2
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{ {
@ -3062,10 +3062,11 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
const int a_tile_offset = 32*16 + 16; const int a_tile_offset = 32*16 + 16;
const int b_tile_offset = 16*8 + 16; const int b_tile_offset = 16*8 + 16;
const int c_tile_offset = 32*8 + 24;
__shared__ T smem_A[WARPS*32*16*2 + (16*1)]; __shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))];
__shared__ T smem_B[WARPS*16*8*2 + (16*1)]; __shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))];
__shared__ T smem_C[WARPS*32*8]; __shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))];
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag; wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag; wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
@ -3092,46 +3093,45 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//int block_idx = 0; //int block_idx = 0;
//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x) //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
for(int base_idx = 0; base_idx < K; base_idx+=32) for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
{ {
int idx = base_idx + threadIdx.x; int idx = base_idx + threadIdx.x;
if(idx >= K) if(idx >= K)
{ {
smem_A[threadIdx.x] = 0.0f; smem_A[threadIdx.x] = 0.0f;
//smem_B[threadIdx.x] = 0.0f; //smem_B[threadIdx.x] = 0.0f;
} }
else else
{ {
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx];
smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx]; for(int col = 0; col < 8; col++)
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx];
for(int col = 0; col < 8; col++) }
smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb+idx];
}
__syncthreads(); __syncthreads();
wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
wmma::load_matrix_sync(a2_frag, &(smem_A[a_tile_offset]), 16); // 111 mu wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b2_frag, &(smem_B[b_tile_offset]), 16); // 35 mu wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag); wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag);
} }
// 129 mu // 129 mu
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major); wmma::store_matrix_sync(&(smem_C[half_warp_id*c_tile_offset]), c_frag, 8, wmma::mem_row_major);
__syncthreads(); __syncthreads();
//if(threadIdx.x >= 16){ return; } //if(threadIdx.x >= 16){ return; }
//printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]); //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
//if(threadIdx.x < 32) //if(threadIdx.x < 32)
//if(warp_lane < 8 && warp_id > 0) if(half_warp_lane < 8 && half_warp_id > 0)
// //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)]; //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
// atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]); atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]);
//__syncthreads(); __syncthreads();
//local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
//if(threadIdx.x == 0) //if(threadIdx.x == 0)

View File

@ -693,7 +693,8 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16) if(bits == 16)
//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
gemm_device<T, 16, 64><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
} }
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)