Slow non-vector 530.
This commit is contained in:
parent
ad07d254fb
commit
604bb3fb57
106
csrc/kernels.cu
106
csrc/kernels.cu
|
@ -3041,6 +3041,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
|
|||
}
|
||||
}
|
||||
|
||||
#define WARPS 1
|
||||
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
|
||||
{
|
||||
|
||||
|
@ -3059,9 +3060,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
T local_B[64/BITS];
|
||||
T local_C[8];
|
||||
|
||||
__shared__ T smem_A[4*32*16];
|
||||
__shared__ T smem_B[4*16*8];
|
||||
__shared__ T smem_C[4*32*8];
|
||||
__shared__ T smem_A[WARPS*32*16];
|
||||
__shared__ T smem_B[WARPS*16*8];
|
||||
__shared__ T smem_C[WARPS*32*8];
|
||||
|
||||
wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
|
||||
wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
|
||||
|
@ -3070,13 +3071,13 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
wmma::fill_fragment(c_frag, 0.0f);
|
||||
|
||||
|
||||
for(int i = threadIdx.x; i < 32*16*4; i+=blockDim.x)
|
||||
for(int i = threadIdx.x; i < 32*16*WARPS; i+=blockDim.x)
|
||||
smem_A[i] = T(0);
|
||||
|
||||
for(int i = threadIdx.x; i < 32*8*4; i+=blockDim.x)
|
||||
for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
|
||||
smem_B[i] = T(0);
|
||||
|
||||
for(int i = threadIdx.x; i < 32*8*THREADS/32; i+=blockDim.x)
|
||||
for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
|
||||
smem_C[i] = T(0);
|
||||
__syncthreads();
|
||||
|
||||
|
@ -3084,91 +3085,48 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
for(int k = 0; k < 8; k++)
|
||||
local_C[k] = T(0);
|
||||
|
||||
int block_idx = 0;
|
||||
//int block_idx = 0;
|
||||
//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
|
||||
for(int base_idx = 0; base_idx < K; base_idx+=64)
|
||||
for(int base_idx = 0; base_idx < K; base_idx+=16)
|
||||
{
|
||||
int idx = base_idx + threadIdx.x;
|
||||
|
||||
int tidx = threadIdx.x*4;
|
||||
|
||||
if(base_idx % (4*blockDim.x) == 0)
|
||||
if(threadIdx.x < 16)
|
||||
{
|
||||
vector_load<T, int2, 64/BITS>(local_A, A, base_idx+tidx, base_idx+tidx, K); // 54 mu
|
||||
block_idx = 0;
|
||||
if(idx >= K)
|
||||
{
|
||||
smem_A[threadIdx.x] = 0.0f;
|
||||
smem_B[threadIdx.x] = 0.0f;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
smem_A[threadIdx.x] = A[idx];
|
||||
|
||||
for(int col = 0; col < 8; col++)
|
||||
smem_B[threadIdx.x + (col*16)] = B[(col_offset+col)*ldb+idx];
|
||||
}
|
||||
}
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
if((threadIdx.x >= block_idx*16) && (threadIdx.x < (block_idx+1)*16))
|
||||
smem_A[(threadIdx.x % 16) + (32*16*k)] = local_A[k]; // 54 mu
|
||||
}
|
||||
block_idx += 1;
|
||||
|
||||
// 4 warps, 1 warps loads in total 4*32=64 values -> 4 columns at a time
|
||||
// we need 8 columns, so 2 loads and smem stores
|
||||
// we need a half-warp to load one column at a time
|
||||
for(int j = 0; j < 2; j++)
|
||||
{
|
||||
int col = warp_id + (j*4);
|
||||
int offset_B = (col_offset+col)*ldb;
|
||||
vector_load<T, int2, 64/BITS>(local_B, B, offset_B+base_idx+warp_lane*4, base_idx+warp_lane*4, K); // 171 mu
|
||||
|
||||
|
||||
//#pragma unroll 4
|
||||
//for(int k = 0; k < 4; k++)
|
||||
// if((float)local_B[k] != 0.0)
|
||||
// printf("%i %i %i %i %f\n", j, warp_id, warp_lane, k, (float)local_B[k]);
|
||||
|
||||
// load and store is different
|
||||
// we wnat to load 64 consequitive values with one warp
|
||||
// but we need to store those across 4 fragments since
|
||||
// the max column width is 16.
|
||||
|
||||
// each 16 values a new tile for each warp
|
||||
//int tile_idx = warp_lane/16;
|
||||
#pragma unroll 4
|
||||
for(int k = 0; k < 4; k++)
|
||||
smem_B[(warp_lane % 16) + (col*16) + (k*16*8)] = local_B[k]; // 171 mu
|
||||
}
|
||||
|
||||
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//if(threadIdx.x == 0)
|
||||
// for(int w = 0; w < 4; w++)
|
||||
// for(int trow = 0; trow < 32; trow++)
|
||||
// for(int tcol = 0; tcol < 16; tcol++)
|
||||
// if((float)smem_A[trow + tcol*32 + (w*32*16)] != 0.0)
|
||||
// printf("A %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
|
||||
|
||||
//if(threadIdx.x == 0)
|
||||
// for(int w = 0; w < 4; w++)
|
||||
// for(int trow = 0; trow < 16; trow++)
|
||||
// for(int tcol = 0; tcol < 8; tcol++)
|
||||
// if((float)smem_B[trow + tcol*16 + (w*16*8)] != 0.0)
|
||||
// printf("B %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
|
||||
|
||||
|
||||
//__syncthreads();
|
||||
|
||||
wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*32*16]), 16); // 111 mu
|
||||
wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*16*8]), 16); // 35 mu
|
||||
wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
|
||||
wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
|
||||
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
|
||||
}
|
||||
|
||||
// 129 mu
|
||||
wmma::store_matrix_sync(&(smem_C[warp_id*32*8]), c_frag, 8, wmma::mem_row_major);
|
||||
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major);
|
||||
__syncthreads();
|
||||
|
||||
//if(threadIdx.x >= 16){ return; }
|
||||
//printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
|
||||
|
||||
//if(threadIdx.x < 32)
|
||||
if(warp_lane < 8 && warp_id > 0)
|
||||
//local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
|
||||
atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
|
||||
__syncthreads();
|
||||
//if(warp_lane < 8 && warp_id > 0)
|
||||
// //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
|
||||
// atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
|
||||
//__syncthreads();
|
||||
|
||||
//local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
|
||||
//if(threadIdx.x == 0)
|
||||
|
@ -3487,12 +3445,14 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
|
|||
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
||||
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
// these are not used and make no sense, but the compiler needs them
|
||||
|
||||
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
||||
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
||||
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
|
||||
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
|
|
|
@ -692,8 +692,8 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
|
|||
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
if(bits == 16)
|
||||
gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
}
|
||||
|
||||
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
|
||||
|
|
Loading…
Reference in New Issue
Block a user