Slow non-vector 530.

Tim Dettmers 2023-04-30 18:06:01 -07:00
parent ad07d254fb
commit 604bb3fb57
2 changed files with 35 additions and 75 deletions

View File

@@ -3041,6 +3041,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
}
}
+#define WARPS 1
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
@@ -3059,9 +3060,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
T local_B[64/BITS];
T local_C[8];
-__shared__ T smem_A[4*32*16];
-__shared__ T smem_B[4*16*8];
-__shared__ T smem_C[4*32*8];
+__shared__ T smem_A[WARPS*32*16];
+__shared__ T smem_B[WARPS*16*8];
+__shared__ T smem_C[WARPS*32*8];
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
@@ -3070,13 +3071,13 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
wmma::fill_fragment(c_frag, 0.0f);
-for(int i = threadIdx.x; i < 32*16*4; i+=blockDim.x)
+for(int i = threadIdx.x; i < 32*16*WARPS; i+=blockDim.x)
smem_A[i] = T(0);
-for(int i = threadIdx.x; i < 32*8*4; i+=blockDim.x)
+for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
smem_B[i] = T(0);
-for(int i = threadIdx.x; i < 32*8*THREADS/32; i+=blockDim.x)
+for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
smem_C[i] = T(0);
__syncthreads();
@@ -3084,91 +3085,48 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
for(int k = 0; k < 8; k++)
local_C[k] = T(0);
-int block_idx = 0;
+//int block_idx = 0;
//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
-for(int base_idx = 0; base_idx < K; base_idx+=64)
+for(int base_idx = 0; base_idx < K; base_idx+=16)
{
+int idx = base_idx + threadIdx.x;
-int tidx = threadIdx.x*4;
-if(base_idx % (4*blockDim.x) == 0)
+if(threadIdx.x < 16)
{
-vector_load<T, int2, 64/BITS>(local_A, A, base_idx+tidx, base_idx+tidx, K); // 54 mu
-block_idx = 0;
+if(idx >= K)
+{
+smem_A[threadIdx.x] = 0.0f;
+smem_B[threadIdx.x] = 0.0f;
+}
+else
+{
+smem_A[threadIdx.x] = A[idx];
+for(int col = 0; col < 8; col++)
+smem_B[threadIdx.x + (col*16)] = B[(col_offset+col)*ldb+idx];
+}
}
-for(int k = 0; k < 4; k++)
-{
-if((threadIdx.x >= block_idx*16) && (threadIdx.x < (block_idx+1)*16))
-smem_A[(threadIdx.x % 16) + (32*16*k)] = local_A[k]; // 54 mu
-}
-block_idx += 1;
-// 4 warps, 1 warp loads in total 4*32=64 values -> 4 columns at a time
-// we need 8 columns, so 2 loads and smem stores
-// we need a half-warp to load one column at a time
-for(int j = 0; j < 2; j++)
-{
-int col = warp_id + (j*4);
-int offset_B = (col_offset+col)*ldb;
-vector_load<T, int2, 64/BITS>(local_B, B, offset_B+base_idx+warp_lane*4, base_idx+warp_lane*4, K); // 171 mu
-//#pragma unroll 4
-//for(int k = 0; k < 4; k++)
-// if((float)local_B[k] != 0.0)
-// printf("%i %i %i %i %f\n", j, warp_id, warp_lane, k, (float)local_B[k]);
-// load and store is different
-// we want to load 64 consecutive values with one warp
-// but we need to store those across 4 fragments since
-// the max column width is 16.
-// each 16 values a new tile for each warp
-//int tile_idx = warp_lane/16;
-#pragma unroll 4
-for(int k = 0; k < 4; k++)
-smem_B[(warp_lane % 16) + (col*16) + (k*16*8)] = local_B[k]; // 171 mu
-}
__syncthreads();
-//if(threadIdx.x == 0)
-// for(int w = 0; w < 4; w++)
-// for(int trow = 0; trow < 32; trow++)
-// for(int tcol = 0; tcol < 16; tcol++)
-// if((float)smem_A[trow + tcol*32 + (w*32*16)] != 0.0)
-// printf("A %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
-//if(threadIdx.x == 0)
-// for(int w = 0; w < 4; w++)
-// for(int trow = 0; trow < 16; trow++)
-// for(int tcol = 0; tcol < 8; tcol++)
-// if((float)smem_B[trow + tcol*16 + (w*16*8)] != 0.0)
-// printf("B %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
-//__syncthreads();
-wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*32*16]), 16); // 111 mu
-wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*16*8]), 16); // 35 mu
+wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
+wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
// 129 mu
-wmma::store_matrix_sync(&(smem_C[warp_id*32*8]), c_frag, 8, wmma::mem_row_major);
+wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major);
__syncthreads();
//if(threadIdx.x >= 16){ return; }
//printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
//if(threadIdx.x < 32)
-if(warp_lane < 8 && warp_id > 0)
-//local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
-atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
-__syncthreads();
+//if(warp_lane < 8 && warp_id > 0)
+// //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
+// atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
+//__syncthreads();
//local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
//if(threadIdx.x == 0)
@@ -3487,12 +3445,14 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
// these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
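
A minimal, self-contained sketch of the "slow non-vector" scheme the kernel now uses may help readers outside the diff context. Everything below is illustrative, not the library's code: the kernel name, the single-output-row (M=1) restriction, and the launch shape are assumptions; only the staging pattern and the 32x8x16 wmma tile mirror the changed lines above.

// sketch_gemv_tile.cu -- hypothetical; compile with: nvcc -arch=sm_70
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

__global__ void gemv_tile_sketch(int K, const half * __restrict__ A,
                                 const half * __restrict__ B,
                                 half *out, int ldb)
{
    // WARPS == 1: one 32x16 A tile, one 16x8 B tile, one 32x8 C tile.
    __shared__ half smem_A[32*16];
    __shared__ half smem_B[16*8];
    __shared__ half smem_C[32*8];

    wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;
    wmma::fill_fragment(c_frag, 0.0f);

    // Zero both input tiles once; only row 0 of the A tile is ever refilled.
    for(int i = threadIdx.x; i < 32*16; i += blockDim.x) smem_A[i] = 0.0f;
    for(int i = threadIdx.x; i < 16*8; i += blockDim.x) smem_B[i] = 0.0f;
    __syncthreads();

    // Consume K sixteen elements at a time; only 16 of the 32 threads stage
    // data, one scalar of A and eight scalars of B each -- hence "non-vector".
    for(int base_idx = 0; base_idx < K; base_idx += 16)
    {
        int idx = base_idx + threadIdx.x;
        if(threadIdx.x < 16)
        {
            if(idx >= K)
            {
                smem_A[threadIdx.x] = 0.0f;          // zero-pad the tail tile
                for(int col = 0; col < 8; col++)
                    smem_B[threadIdx.x + col*16] = 0.0f;
            }
            else
            {
                smem_A[threadIdx.x] = A[idx];
                for(int col = 0; col < 8; col++)     // B tile is column-major
                    smem_B[threadIdx.x + col*16] = B[col*ldb + idx];
            }
        }
        __syncthreads();

        wmma::load_matrix_sync(a_frag, smem_A, 16);
        wmma::load_matrix_sync(b_frag, smem_B, 16);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        __syncthreads();   // tiles are overwritten on the next iteration
    }

    wmma::store_matrix_sync(smem_C, c_frag, 8, wmma::mem_row_major);
    __syncthreads();
    if(threadIdx.x < 8)    // row 0 of C holds the 8 accumulated dot products
        out[threadIdx.x] = smem_C[threadIdx.x];
}

Launched as gemv_tile_sketch<<<1, 32>>>(K, A, B, out, ldb), this computes eight length-K dot products with one tensor-core tile per 16-wide K slice; the scalar staging loop is exactly the part the parent commit's vector_load path was meant to accelerate.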

View File

@@ -692,8 +692,8 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16)
-gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
-//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
}
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
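
For completeness, a hedged sketch of how the host-side switch pairs with the kernel change: with WARPS defined as 1, each block must be exactly one warp, which is why the launch moves from 128 threads to 32. The num_blocks formula and function name below are assumptions for illustration; only the <T, 16, 32> launch itself comes from the diff.

// Hypothetical dispatcher mirroring the new launch configuration.
template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T * __restrict__ const A,
                            T *B, T *out, int lda, int ldb, int ldc); // as instantiated above

template <typename T>
void gemm_host_sketch(int m, int n, int k, T *A, T *B, T *out,
                      int lda, int ldb, int ldc, int bits)
{
    int num_blocks = (n + 7) / 8;   // assumption: 8 output columns per block
    if(bits == 16)
        // 32 threads == 1 warp == WARPS: one 32x8x16 wmma tile per block
        gemm_device<T, 16, 32><<<num_blocks, 32, 0, 0>>>(m, n, k, A, B, out,
                                                         lda, ldb, ldc);
}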