diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index bb3cde3..774e954 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1429,7 +1429,7 @@ def cutlass3_gemm(
             m = sB[1]
             k = sB[0]

-            lda = B.stride()[(1 if transposed_B else 0)]
+            lda = B.stride()[0]
             ldc = sB[1]
         elif len(sB) == 3:
             # special case
@@ -1446,7 +1446,7 @@ def cutlass3_gemm(
             n = sA[2]
             k = sB[0] * sB[1]

-            lda = m
+            lda = n
             ldb = sA[2]
             ldc = m

@@ -1454,7 +1454,7 @@ def cutlass3_gemm(

    # B^T @ A^T = C^T
    # [km, nk -> mn]
-    lda = ldb = ldc = 1
+    #lda = ldb = ldc = 1
    #lda = 1
    #print(m, n, k, lda, ldb, ldc)
    is_on_gpu([B, A, out])
@@ -1466,7 +1466,7 @@ def cutlass3_gemm(
    ldc = ct.c_int32(ldc)
    alpha = ct.c_float(1.0)
    beta = ct.c_float(0.0)
-    lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), ldb, get_ptr(B), lda, beta, get_ptr(out), ldc)
+    lib.ccutlass_gemm(m, n, k, alpha, get_ptr(A), lda, get_ptr(B), ldb, beta, get_ptr(out), ldc)

    return out

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 775716f..91169dd 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2947,9 +2947,11 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
 //// 9. write outputs to matmul output matrix
 //}

+
+#define ROWS 2
 __global__ void gemm_device(int M, int N, int K,
                             float const* A,
-                            float const* B,
+                            float* B,
                             float * out, int lda, int ldb, int ldc,
                             float alpha, float beta)
 {
@@ -2958,29 +2960,106 @@ __global__ void gemm_device(int M, int N, int K,
 // 2. Dequantize B
 // 3. Fetch data from A and multiply

-  typedef cub::BlockLoad<float, 256, 1> LoadA;
-  __shared__ typename LoadA::TempStorage loada;
-  float dataA[1];
+  typedef cub::BlockLoad<float, 256, 4> LoadA;
+  //__shared__ typename LoadA::TempStorage loada;
+  typedef cub::BlockLoad<float, 256, 4> LoadB;
+  //__shared__ typename LoadB::TempStorage loadb;
+  typedef cub::BlockReduce<float, 256> BlockReduce;
+  // Allocate shared memory for BlockReduce
+  //__shared__ typename BlockReduce::TempStorage reduce;
+
+  __shared__ union {
+    typename BlockReduce::TempStorage reduce;
+    typename LoadB::TempStorage loadb;
+    typename LoadA::TempStorage loada;
+  } temp_storage;
+
+
+  float dataA[4];
+  float local_B[4];
+  float local_accC[ROWS];
   int valid_items = 0;
+  const int warp_id = threadIdx.x/32;
+  const int warp_lane = threadIdx.x % 32;
+  const int col_offset = blockIdx.x * 8;

-  __shared__ float[16*256] tileA;
+  __shared__ float tileA[ROWS*1024];
+  __shared__ float accumulatorC[ROWS*8];
+
+  //#pragma unroll 8
+  //for(int i = 0; i < 8; i++)
+  //  tileA[threadIdx.x + (i*256)] = 0.0f;
+  //__syncthreads();
+  if(threadIdx.x < 64)
+    accumulatorC[threadIdx.x] = 0.0f;
+  __syncthreads();

-  for(int idxA = 0; idxA < M*K; idxA+= 256)
+  for(int inner_idx = 0; inner_idx < K; inner_idx+= 1024)
   {
-    valid_items = M*K - idxA > 256 ? 256 : M*K - idxA;
+    valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx;
     int baserow = 0;
-    for(int row = baserow; row < baserow+16 && row < M; row++)
+    for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
     {
-      LoadA(loada).Load(&(A[(row*lda) + i]), dataA, valid_items, 0.0f);
-      tileA[row*256 + threadIdx.x] = dataA[0];
+      LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
+
+      #pragma unroll 4
+      for(int k = 0; k < 4; k++)
+        tileA[row*1024 + threadIdx.x + (k*blockDim.x)] = dataA[k];
+      __syncthreads();
     }
-    baserow += 16;
+    baserow += ROWS;

+    // load 16 columns from B at a time. B is transposed, so it's like loading rows
+    // each warp loads one row
+    // each thread loads 128 bytes
+    // col: inner_idx + warp_lane
+    // row: ldb*(offset + warp_id)
+    for(int col = 0; col < 8 && (col_offset + col) < M; col++)
+    {
+      int colB = col_offset + col;
+
+      for(int k = 0; k < ROWS; k++)
+        local_accC[k] = 0.0f;
+
+      int base_idxB = ldb*colB;
+      valid_items = K - inner_idx > 1024 ? 1024 : K - inner_idx;
+      LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
+      __syncthreads();
+
+      for(int row = 0; row < ROWS && row < N; row++)
+      {
+        #pragma unroll 4
+        for(int k = 0; k < 4; k++)
+        {
+          int idxA = row*1024 + threadIdx.x + (blockDim.x*k);
+          local_accC[row] += tileA[idxA]*local_B[k];
+        }
+
+        local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
+        if(threadIdx.x == 0)
+          atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
+      }
+    }
   }

+  for(int row = 0; row < ROWS && row < N; row++)
+  {
+    int out_idx = ldc*row + col_offset;
+
+    //if(threadIdx.x < 8)
+    //  if(accumulatorC[row*8 + threadIdx.x] != 0.0)
+    //    printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
+
+    if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
+    {
+      //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
+      out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
+    }
+  }
 }

diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 37e214a..55397e7 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -140,7 +140,7 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,

 __global__ void gemm_device(int M, int N, int K,
                             float const* A,
-                            float const* B,
+                            float * B,
                             float * out, int lda, int ldb, int ldc,
                             float alpha, float beta);

diff --git a/csrc/ops.cu b/csrc/ops.cu
index ee585bb..dd8fade 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -669,8 +669,6 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)

   int threads = 256;
   int num_blocks = (n+(256*batch_size)+1)/(batch_size*256);
-  printf("%i %i\n", num_blocks, batch_size);
-
   with_staging_unified<2><<<num_blocks, threads>>>(A, B, n, batch_size);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
@@ -680,15 +678,22 @@
 void gemm_host(int m, int n, int k,
     float alpha,
     float const* A, int lda,
-    float const* B, int ldb,
+    float * B, int ldb,
     float beta,
     float * C, int ldc)
 {

   dim3 dimBlock(256);
-  int num_blocks = (n+31)/32;
+  int num_blocks = (m+7)/8;

   cout << num_blocks << endl;
+  cout << lda << endl;
+  cout << ldb << endl;
+  cout << ldc << endl;
+
+  cout << m << endl;
+  cout << n << endl;
+  cout << k << endl;
   gemm_device
     <<< num_blocks, dimBlock, 0, 0 >>>
     (m, n, k,
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 83dd4e5..2f71966 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -193,7 +193,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 void gemm_host(int m, int n, int k,
     float alpha,
     float const* A, int ldA,
-    float const* B, int ldB,
+    float * B, int ldB,
     float beta,
     float * C, int ldC);

diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 170093f..6ec5501 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -24,7 +24,7 @@
 void cppgemm(int m, int n, int k,
    float alpha,
    float const* A, int ldA,
-   float const* B, int ldB,
+   float * B, int ldB,
    float beta,
    float * C, int ldC)
 { gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
@@ -320,7 +320,7 @@ extern "C"

  void ccutlass_gemm(int m, int n, int k, float alpha,
    float const* A, int ldA,
-   float const* B, int ldB,
+   float * B, int ldB,
    float beta,
    float * C, int ldC)
  { cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 7dec375..087bc84 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2353,17 +2353,19 @@ def test_normal_map_tree():


 def test_cutlass3_gemm():
-    #A = torch.rand(2, 2).cuda()
-    #B = torch.rand(2, 2).cuda()
-    A = torch.arange(4).reshape(2, 2).float().cuda().contiguous()
-    B = torch.ones(2, 2).float().cuda()
+    A = torch.rand(2, 4092).cuda()
+    B = torch.rand(4*4092, 4092).cuda()

-    print('')
-    print(A)
-    print(B)
+    #print('')
+    #print(A)
+    #print(B.t())

-    C1 = torch.matmul(A, B)
-    C2 = F.cutlass3_gemm(A, B)
+    C1 = torch.matmul(A, B.t())
+    C2 = F.cutlass3_gemm(A, B.t())
+    #print(C1)
+    #print(C2)
+
+    torch.testing.assert_close(C1, C2)


 def test_pipeline_func():
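
Note (not part of the patch): a minimal usage sketch of the new code path, mirroring the updated test_cutlass3_gemm above. It assumes a CUDA device and a bitsandbytes build that includes this patch (so that ccutlass_gemm is exported); the shapes are the ones used in the test, and torch.matmul serves as the reference.

    import torch
    from bitsandbytes import functional as F

    # A: [2, 4092], B: [4*4092, 4092]; the new path consumes B.t(), as in the test
    A = torch.rand(2, 4092).cuda()
    B = torch.rand(4*4092, 4092).cuda()

    C1 = torch.matmul(A, B.t())      # reference result
    C2 = F.cutlass3_gemm(A, B.t())   # path exercised by this patch

    torch.testing.assert_close(C1, C2)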