diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 54a08a1..bb3cde3 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2341,3 +2341,8 @@ def extract_outliers(A, SA, idx):
     post_call(prev_device)
 
     return out
+
+def pipeline_test(A, batch_size):
+    out = torch.zeros_like(A)
+    lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
+    return out
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index ed87c69..775716f 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -15,6 +15,9 @@
 #include 
 #include 
 
+#include <cooperative_groups.h>
+#include <cuda/pipeline>
+
 #define HLF_MAX 65504
 #define TH 1024
 #define NUM 4
@@ -2983,6 +2986,51 @@ __global__ void gemm_device(int M, int N, int K,
 }
 
+__device__ void compute(float* global_out, float const* shared_in)
+{
+
+}
+template <size_t stages_count>
+__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz) {
+    auto grid = cooperative_groups::this_grid();
+    auto block = cooperative_groups::this_thread_block();
+    assert(size == batch_sz * grid.size()); // Assume input size fits batch_sz * grid_size
+
+    extern __shared__ float shared[]; // stages_count * block.size() * sizeof(float) bytes
+    size_t shared_offset[stages_count];
+    for (int s = 0; s < stages_count; ++s) shared_offset[s] = s * block.size();
+
+    __shared__ cuda::pipeline_shared_state<
+        cuda::thread_scope::thread_scope_block,
+        stages_count
+    > shared_state;
+    auto pipeline = cuda::make_pipeline(block, &shared_state);
+
+    auto block_batch = [&](size_t batch) -> int {
+        return block.group_index().x * block.size() + grid.size() * batch;
+    };
+
+    // compute_batch: next batch to process
+    // fetch_batch:   next batch to fetch from global memory
+    for (size_t compute_batch = 0, fetch_batch = 0; compute_batch < batch_sz; ++compute_batch) {
+        // The outer loop iterates over the computation of the batches
+        for (; fetch_batch < batch_sz && fetch_batch < (compute_batch + stages_count); ++fetch_batch) {
+            // This inner loop iterates over the memory transfers, making sure that the pipeline is always full
+            pipeline.producer_acquire();
+            size_t shared_idx = fetch_batch % stages_count;
+            size_t batch_idx = fetch_batch;
+            size_t block_batch_idx = block_batch(batch_idx);
+            cuda::memcpy_async(block, shared + shared_offset[shared_idx], global_in + block_batch_idx, sizeof(float) * block.size(), pipeline);
+            pipeline.producer_commit();
+        }
+        pipeline.consumer_wait();
+        int shared_idx = compute_batch % stages_count;
+        int batch_idx = compute_batch;
+        compute(global_out + block_batch(batch_idx), shared + shared_offset[shared_idx]);
+        pipeline.consumer_release();
+    }
+}
+
 
 //==============================================================
 //                   TEMPLATE DEFINITIONS
 //==============================================================
@@ -3004,6 +3052,7 @@ __global__ void gemm_device(int M, int N, int K,
 
 //template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
 
+template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index ba6de59..37e214a 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -135,6 +135,8 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
 //                      TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 //                      TC      * out, CStride dC, CBlockLayout       , CThreadLayout tC,
 //                      Alpha alpha, Beta beta);
+template <size_t stages_count>
+__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
 __global__ void gemm_device(int M, int N, int K,
                             float const* A,
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 8933927..ee585bb 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -663,6 +663,17 @@ template void extractOutliers(char * A, int *idx, char *out, int id
 }
 
 
+void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
+{
+
+  int threads = 256;
+  int num_blocks = (n+(256*batch_size)-1)/(batch_size*256);
+
+  printf("%d %zu\n", num_blocks, batch_size);
+
+  with_staging_unified<2><<<num_blocks, threads, 2*threads*sizeof(float)>>>(A, B, n, batch_size);
+  CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 843a9bb..83dd4e5 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -197,4 +197,6 @@ void gemm_host(int m, int n, int k,
                float beta,
                float * C,
                int ldC);
+
+void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
 #endif
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index c6de62d..170093f 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -315,6 +315,7 @@ extern "C"
 	void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
 	void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
+	void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }
 
 	void ccutlass_gemm(int m, int n, int k, float alpha,
diff --git a/tests/test_functional.py b/tests/test_functional.py
index dd41972..7dec375 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2366,3 +2366,9 @@ def test_cutlass3_gemm():
     C2 = F.cutlass3_gemm(A, B)
 
 
+def test_pipeline_func():
+    # numel must equal batch_size * num_blocks * 256 to satisfy the kernel's size assert
+    a = torch.rand(2, 256).cuda()
+    out = F.pipeline_test(a, 2)
+    print(a)
+    print(out)
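
Note (not part of the patch): compute() is left as an empty stub, so with_staging_unified only exercises the staged cuda::memcpy_async path and pipeline_test returns an all-zero tensor. A minimal hypothetical body that writes each staged tile back to global memory, so the output can be checked against the input, could look like the sketch below. It assumes compute() is called by every thread of the block after consumer_wait(), with global_out already offset to the block's slice via block_batch() and shared_in pointing at the current stage's buffer.

__device__ void compute(float* global_out, float const* shared_in)
{
    // Hypothetical placeholder, not in the patch: each thread writes back the
    // one element of the block's tile that was staged into shared memory, so
    // after pipeline_test the output tensor equals the input tensor.
    global_out[threadIdx.x] = shared_in[threadIdx.x];
}

With such a body in place, test_pipeline_func could assert torch.equal(a, out) instead of only printing the two tensors.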