diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 24b004b..5a6db7d 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3041,6 +3041,7 @@ template <typename T, int ILP> __device__ inline void vector_l
 	}
 }
 
+#define WARPS 1
 template <typename T, int THREADS, int BITS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
@@ -3059,9 +3060,9 @@ template <typename T, int THREADS, int BITS> __global__ void gemm_device(int M,
 	T local_B[64/BITS];
 	T local_C[8];
 
-	__shared__ T smem_A[4*32*16];
-	__shared__ T smem_B[4*16*8];
-	__shared__ T smem_C[4*32*8];
+	__shared__ T smem_A[WARPS*32*16];
+	__shared__ T smem_B[WARPS*16*8];
+	__shared__ T smem_C[WARPS*32*8];
 
 	wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
 	wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
@@ -3070,13 +3071,13 @@ template <typename T, int THREADS, int BITS> __global__ void gemm_device(int M,
 
 	wmma::fill_fragment(c_frag, 0.0f);
 
-	for(int i = threadIdx.x; i < 32*16*4; i+=blockDim.x)
+	for(int i = threadIdx.x; i < 32*16*WARPS; i+=blockDim.x)
 		smem_A[i] = T(0);
 
-	for(int i = threadIdx.x; i < 32*8*4; i+=blockDim.x)
+	for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
 		smem_B[i] = T(0);
 
-	for(int i = threadIdx.x; i < 32*8*THREADS/32; i+=blockDim.x)
+	for(int i = threadIdx.x; i < 32*8*WARPS; i+=blockDim.x)
 		smem_C[i] = T(0);
 
 	__syncthreads();
@@ -3084,91 +3085,48 @@ template <typename T, int THREADS, int BITS> __global__ void gemm_device(int M,
 	for(int k = 0; k < 8; k++)
 		local_C[k] = T(0);
 
-	int block_idx = 0;
+	//int block_idx = 0;
 	//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
-	for(int base_idx = 0; base_idx < K; base_idx+=64)
+	for(int base_idx = 0; base_idx < K; base_idx+=16)
 	{
+		int idx = base_idx + threadIdx.x;
-		int tidx = threadIdx.x*4;
-
-		if(base_idx % (4*blockDim.x) == 0)
+		if(threadIdx.x < 16)
 		{
-			vector_load<T, 4>(local_A, A, base_idx+tidx, base_idx+tidx, K); // 54 mu
-			block_idx = 0;
+			if(idx >= K)
+			{
+				smem_A[threadIdx.x] = 0.0f;
+				smem_B[threadIdx.x] = 0.0f;
+			}
+			else
+			{
+
+				smem_A[threadIdx.x] = A[idx];
+
+				for(int col = 0; col < 8; col++)
+					smem_B[threadIdx.x + (col*16)] = B[(col_offset+col)*ldb+idx];
+			}
 		}
-		for(int k = 0; k < 4; k++)
-		{
-			if((threadIdx.x >= block_idx*16) && (threadIdx.x < (block_idx+1)*16))
-				smem_A[(threadIdx.x % 16) + (32*16*k)] = local_A[k]; // 54 mu
-		}
-		block_idx += 1;
-
-		// 4 warps; 1 warp loads 4*32=64 values in total -> 4 columns at a time
-		// we need 8 columns, so 2 loads and smem stores
-		// we need a half-warp to load one column at a time
-		for(int j = 0; j < 2; j++)
-		{
-			int col = warp_id + (j*4);
-			int offset_B = (col_offset+col)*ldb;
-			vector_load<T, 4>(local_B, B, offset_B+base_idx+warp_lane*4, base_idx+warp_lane*4, K); // 171 mu
-
-
-			//#pragma unroll 4
-			//for(int k = 0; k < 4; k++)
-			//	if((float)local_B[k] != 0.0)
-			//		printf("%i %i %i %i %f\n", j, warp_id, warp_lane, k, (float)local_B[k]);
-
-			// load and store patterns are different:
-			// we want to load 64 consecutive values with one warp,
-			// but we need to store those across 4 fragments since
-			// the max column width is 16.
-
-			// each 16 values a new tile for each warp
-			//int tile_idx = warp_lane/16;
-			#pragma unroll 4
-			for(int k = 0; k < 4; k++)
-				smem_B[(warp_lane % 16) + (col*16) + (k*16*8)] = local_B[k]; // 171 mu
-		}
-
-
-		__syncthreads();
-		//if(threadIdx.x == 0)
-		//	for(int w = 0; w < 4; w++)
-		//		for(int trow = 0; trow < 32; trow++)
-		//			for(int tcol = 0; tcol < 16; tcol++)
-		//				if((float)smem_A[trow + tcol*32 + (w*32*16)] != 0.0)
-		//					printf("A %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
-
-		//if(threadIdx.x == 0)
-		//	for(int w = 0; w < 4; w++)
-		//		for(int trow = 0; trow < 16; trow++)
-		//			for(int tcol = 0; tcol < 8; tcol++)
-		//				if((float)smem_B[trow + tcol*16 + (w*16*8)] != 0.0)
-		//					printf("B %i %i %i = %f\n", w, trow, tcol, (float) smem_B[trow + tcol*16]);
-
-
-		//__syncthreads();
-
-		wmma::load_matrix_sync(a_frag, &(smem_A[warp_id*32*16]), 16); // 111 mu
-		wmma::load_matrix_sync(b_frag, &(smem_B[warp_id*16*8]), 16); // 35 mu
+		wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
+		wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
 		wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
 	}
 
 	// 129 mu
-	wmma::store_matrix_sync(&(smem_C[warp_id*32*8]), c_frag, 8, wmma::mem_row_major);
+	wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major);
 	__syncthreads();
 
 	//if(threadIdx.x >= 16){ return; }
 	//printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
 
 	//if(threadIdx.x < 32)
-	if(warp_lane < 8 && warp_id > 0)
-		//local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
-		atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
-	__syncthreads();
+	//if(warp_lane < 8 && warp_id > 0)
+	//	//local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
+	//	atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
+	//__syncthreads();
 
 	//local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
 	//if(threadIdx.x == 0)
@@ -3487,12 +3445,14 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 
 template __global__ void gemm_device<half, 256, 16>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 //template __global__ void gemm_device<float, 256, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32, 16>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 
 // these are not used and make no sense, but the compiler needs them
 //template __global__ void gemm_device<float, 256, 16>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 256, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 //template __global__ void gemm_device<float, 128, 16>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 128, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 
 template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
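
Note: the kernel change above reduces gemm_device from four warps to a single warp. Per K-step, threads 0-15 stage a 1x16 sliver of A and a 16x8 tile of B into shared memory, and the whole warp then runs one m32n8k16 wmma step, accumulating into c_frag. The standalone sketch below shows just that single-warp wmma pattern with the same fragment shapes and leading dimensions. It is not part of the patch: the kernel name is hypothetical, it assumes sm_70 or newer, and it assumes the 32x16 A tile (row-major) and 16x8 B tile (column-major) are already packed in memory exactly as smem_A/smem_B are laid out above.

    #include <cuda_fp16.h>
    #include <mma.h>

    using namespace nvcuda;

    // Hypothetical one-warp kernel: C (32x8) = A (32x16) * B (16x8).
    __global__ void one_warp_wmma_tile(const half *A, const half *B, half *C)
    {
        __shared__ half smem_A[32*16]; // row-major, leading dimension 16
        __shared__ half smem_B[16*8];  // column-major, leading dimension 16
        __shared__ half smem_C[32*8];  // row-major, leading dimension 8

        // The single warp stages both input tiles into shared memory.
        for(int i = threadIdx.x; i < 32*16; i += blockDim.x) smem_A[i] = A[i];
        for(int i = threadIdx.x; i < 16*8; i += blockDim.x) smem_B[i] = B[i];
        __syncthreads();

        wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
        wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
        wmma::fragment<wmma::accumulator, 32, 8, 16, half> c_frag;

        wmma::fill_fragment(c_frag, 0.0f);
        wmma::load_matrix_sync(a_frag, &smem_A[0], 16);
        wmma::load_matrix_sync(b_frag, &smem_B[0], 16);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        wmma::store_matrix_sync(&smem_C[0], c_frag, 8, wmma::mem_row_major);
        __syncthreads();

        // Write the 32x8 result tile back out.
        for(int i = threadIdx.x; i < 32*8; i += blockDim.x) C[i] = smem_C[i];
    }

Launched as one_warp_wmma_tile<<<1, 32>>>(dA, dB, dC) and compiled with -arch=sm_70 or newer, this is the inner step that gemm_device now repeats for each 16-wide slice of K while accumulating into c_frag.
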
diff --git a/csrc/ops.cu b/csrc/ops.cu
index d83fc6e..5c4f9c0 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -692,8 +692,8 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
 	//gemm_device<T, 128, 16><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
 	//gemm_device<T, 32, 16><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
 	if(bits == 16)
-		gemm_device<T, 128, 16><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
-		//gemm_device<T, 32, 16><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+		//gemm_device<T, 128, 16><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+		gemm_device<T, 32, 16><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
 }
 
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
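
Note: on the ops.cu side the only change is the launch configuration, from 128 threads (four warps) down to 32 (one warp). A minimal sketch of how this call site is driven follows; it is not from the patch, and the num_blocks expression is an assumption inferred from the kernel's 32-row output tile (the actual computation sits outside this hunk).

    // Hypothetical sketch of the surrounding gemm_host logic, not from the patch.
    int num_blocks = (m + 31) / 32;  // assumed: one 32-row output strip per block
    if(bits == 16)
        // the THREADS template argument (32) mirrors the 32-thread launch
        gemm_device<half, 32, 16><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);

This launch is also why kernels.cu adds the matching 32-thread explicit instantiations: without them the linker could not resolve the new specialization. And with a single warp per block, the cross-warp reduction over smem_C becomes dead code, which is why the atomicAdd block in the kernel is commented out rather than adapted.
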