From ec38ba95b0cd6bf3dadfccf366cd8917acf59c4b Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sat, 6 May 2023 11:14:06 -0700
Subject: [PATCH] Added paging.

---
 bitsandbytes/cextension.py |  2 +
 bitsandbytes/functional.py | 55 +++++++++++++++++++++++++++
 csrc/kernels.cu            | 76 ++++++++++----------------------------
 csrc/kernels.cuh           | 18 +--------
 csrc/ops.cu                | 25 ++++++++-----
 csrc/ops.cuh               |  9 ++++-
 csrc/pythonInterface.c     | 32 +++++++++++++++-
 tests/test_functional.py   | 40 +++++++++++++++++---
 8 files changed, 167 insertions(+), 90 deletions(-)

diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 8adca93..17c2a46 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -26,6 +26,8 @@ try:
     lib.cadam_8bit_blockwise_fp32
     lib.get_context.restype = ct.c_void_p
     lib.get_cusparse.restype = ct.c_void_p
+    lib.cget_managed_ptr.restype = ct.c_void_p
+    lib.cget_stream.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
     warn("The installed version of bitsandbytes was compiled without GPU support. "
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index e5b1bf7..f548475 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -130,6 +130,61 @@ class Cusparse_Context:
             cls._instance.initialize()
         return cls._instance
 
+dtype2bytes = {}
+dtype2bytes[torch.float32] = 4
+dtype2bytes[torch.float16] = 2
+dtype2bytes[torch.bfloat16] = 2
+dtype2bytes[torch.uint8] = 1
+dtype2bytes[torch.int8] = 1
+
+def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)):
+    num_bytes = dtype2bytes[dtype]*prod(shape)
+    cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
+    c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
+    new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
+    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape))
+    out.is_paged = True
+    out.page_deviceid = device.index
+    return out
+
+def prefetch_tensor(A, to_cpu=False):
+    assert A.is_paged, 'Only paged tensors can be prefetched!'
+    if to_cpu:
+        deviceid = -1
+    else:
+        deviceid = A.page_deviceid
+
+    num_bytes = dtype2bytes[A.dtype]*A.numel()
+    lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
+
+def elementwise_func(func_name, A, B, value, prefetch=True):
+    func = None
+    if A.dtype == torch.float32:
+        func = getattr(lib, f'c{func_name}_fp32', None)
+        cvalue = ct.c_float(value)
+    elif A.dtype == torch.uint8:
+        func = getattr(lib, f'c{func_name}_uint8', None)
+        cvalue = ct.c_uint8(value)
+
+    if func is None: raise NotImplementedError(f'Function not implemented: {func_name}')
+
+    is_managed = getattr(A, 'is_managed', False)
+    if is_managed and prefetch:
+        prefetch_tensor(A)
+        if B is not None: prefetch_tensor(B)
+
+    func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
+    if A.is_paged or B.is_paged:
+        # paged functions are fully asynchronous
+        # if we return from this function, we want the tensor
+        # to be in the correct state, that is the final state after the
+        # operation occurred. So we synchronize.
+        torch.cuda.synchronize()
+
+def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value)
+def arange(A, device=None): elementwise_func('arange', A, None, 0)
+def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0)
+
 def create_linear_map(signed=True, total_bits=8, add_zero=True):
     sign = (-1.0 if signed else 0.0)
 
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 2373b91..e1a3155 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3522,49 +3522,23 @@ template __global__ void kgemm_4bit_inference(int M, i
 //}
 
-__device__ void compute(float* global_out, float const* shared_in)
+template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n)
 {
-
-}
-template <size_t stages_count>
-__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz) {
-    auto grid = cooperative_groups::this_grid();
-    auto block = cooperative_groups::this_thread_block();
-    assert(size == batch_sz * grid.size()); // Assume input size fits batch_sz * grid_size
-
-    extern __shared__ float shared[]; // stages_count * block.size() * sizeof(int) bytes
-    size_t shared_offset[stages_count];
-    for (int s = 0; s < stages_count; ++s) shared_offset[s] = s * block.size();
-
-    __shared__ cuda::pipeline_shared_state<
-        cuda::thread_scope::thread_scope_block,
-        stages_count
-    > shared_state;
-    auto pipeline = cuda::make_pipeline(block, &shared_state);
-
-    auto block_batch = [&](size_t batch) -> int {
-        return block.group_index().x * block.size() + grid.size() * batch;
-    };
-
-    // compute_batch: next batch to process
-    // fetch_batch: next batch to fetch from global memory
-    for (size_t compute_batch = 0, fetch_batch = 0; compute_batch < batch_sz; ++compute_batch) {
-        // The outer loop iterates over the computation of the batches
-        for (; fetch_batch < batch_sz && fetch_batch < (compute_batch + stages_count); ++fetch_batch) {
-            // This inner loop iterates over the memory transfers, making sure that the pipeline is always full
-            pipeline.producer_acquire();
-            size_t shared_idx = fetch_batch % stages_count;
-            size_t batch_idx = fetch_batch;
-            size_t block_batch_idx = block_batch(batch_idx);
-            cuda::memcpy_async(block, shared + shared_offset[shared_idx], global_in + block_batch_idx, sizeof(float) * block.size(), pipeline);
-            pipeline.producer_commit();
-        }
-        pipeline.consumer_wait();
-        int shared_idx = compute_batch % stages_count;
-        int batch_idx = compute_batch;
-        compute(global_out + block_batch(batch_idx), shared + shared_offset[shared_idx]);
-        pipeline.consumer_release();
+  for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x))
+  {
+    switch(FUNC)
+    {
+      case FILL:
+        A[i] = (T)value;
+        break;
+      case ARANGE:
+        A[i] = (T)i;
+        break;
+      case _MUL:
+        A[i] = A[i]*B[i];
+        break;
     }
+  }
 }
@@ -3572,19 +3546,10 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 // TEMPLATE DEFINITIONS
 //==============================================================
 
-//template
-//__global__ static
-//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
-//void
-//gemm_device(MShape M, NShape N, KShape K,
-// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
-// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
-// TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
-// half alpha, half beta);
+template __global__ void kfunc<float, FILL>(float *A, float *B, float value, long n);
+template __global__ void kfunc<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
+template __global__ void kfunc<float, ARANGE>(float *A, float *B, float value, long n);
+template __global__ void kfunc<float, _MUL>(float *A, float *B, float value, long n);
 
 // these are not used and make no sense, but the compiler needs them
 //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
@@ -3611,9 +3576,6 @@ template __global__ void gemm_device(int M, int N, int K, half * _
 template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-
-//template __global__ void kMatmul_inference_4bit(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
-template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
 template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 4951031..29c6683 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -122,23 +122,9 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 
-//template
-//__global__ static
-//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
-//void
-//gemm_device(MShape M, NShape N, KShape K,
-// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
-// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
-// TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
-// Alpha alpha, Beta beta);
-template <size_t stages_count>
-__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
-
 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
 
+template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
+
 #endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 4d68436..7d13b71 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -663,16 +663,6 @@ template void extractOutliers(char * A, int *idx, char *out, int id
 }
 
-void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
-{
-
-  int threads = 256;
-  int num_blocks = (n+(256*batch_size)+1)/(batch_size*256);
-
-  with_staging_unified<2><<<num_blocks, threads>>>(A, B, n, batch_size);
-  CUDA_CHECK_RETURN(cudaPeekAtLastError());
-}
-
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
 {
@@ -717,10 +707,25 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
 
+template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
+{
+  int threads = 512;
+  int blocks = n/threads;
+  blocks = n % threads == 0 ? blocks : blocks + 1;
+  blocks = blocks > 65535 ? 65535 : blocks;
+  kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n);
+  CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
 
+template void func<float, FILL>(float *A, float *B, float value, long n);
+template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
+template void func<float, ARANGE>(float *A, float *B, float value, long n);
+template void func<float, _MUL>(float *A, float *B, float value, long n);
+
 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 8919c60..e9d2e22 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -93,6 +93,13 @@ typedef enum DataType_t
     NF4 = 2,
 } DataType_t;
 
+typedef enum Funcs_t
+{
+    FILL = 0,
+    ARANGE = 1,
+    _MUL = 2,
+} Funcs_t;
+
 class Context
 {
     public:
@@ -193,6 +200,6 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
 
-void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
 #endif
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 26f16f2..7271430 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -28,6 +28,14 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
 void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
 { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
 
+#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
+void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
+
+MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
+MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
+MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
+MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
+
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
 void fname##32bit_g##gbits(gtype *g, gtype *p, \
@@ -314,7 +322,6 @@ extern "C"
     void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
     void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
-    void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }
 
     //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
     //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
@@ -325,6 +332,29 @@ extern "C"
     void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
     { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
 
+    void *cget_managed_ptr(size_t bytes)
+    {
+        void *ptr;
+        CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
+        CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+        return ptr;
+    }
+
+    void cprefetch(void *ptr, size_t bytes, int device)
+    {
+        CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
+        CUDA_CHECK_RETURN(cudaPeekAtLastError());
+    }
+
+    #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
+    void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \
+
+    CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
+    CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
+    CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
+    CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
+
 #endif
     void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
     void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
diff --git a/tests/test_functional.py b/tests/test_functional.py
index dc4e40d..145c267 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2489,8 +2489,38 @@ def test_gemm_4bit(dtype):
     print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
     print(dim, (max_err.item(), max_relerr.item()))
 
-def test_pipeline_func():
-    a = torch.rand(2, 4).cuda()
-    out = F.pipeline_test(a, 2)
-    print(a)
-    print(out)
+def test_managed():
+    n = 32*10
+    A = F.get_paged(n, n, dtype=torch.float32)
+    B = F.get_paged(n, n, dtype=torch.uint8)
+    B2 = F.get_paged(n, n, dtype=torch.float32)
+    assert A.is_paged
+    assert B.is_paged
+    assert A.page_deviceid==0
+    assert B.page_deviceid==0
+    F.fill(A, 17.0)
+    F.fill(B, 17)
+    F.fill(B2, 2)
+    assert (A==17).sum().item() == n*n
+    assert (B==17).sum().item() == n*n
+    C = A*B.float()
+    assert (C==289).sum().item() == n*n
+    F._mul(A, B2)
+    F._mul(A, B2)
+    F._mul(A, B2)
+    assert (A==17*(2**3)).sum().item() == n*n
+    # F.prefetch_tensor(A)
+    # F.prefetch_tensor(B)
+
+
+    # F.fill(B2, 17.0)
+    # F._mul(A, B2)
+
+    # F.prefetch_tensor(A, to_cpu=True)
+    # F.prefetch_tensor(B, to_cpu=True)
+    # F.prefetch_tensor(B2, to_cpu=True)
+    # torch.cuda.synchronize()
+
+    # assert (A==17).sum().item() == n*n
+
+    # torch.testing.assert_allclose(A, torch.ones(A.shape)*289)
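
Usage sketch (illustrative, not part of the patch): the snippet below shows how the paged-tensor API introduced above fits together, mirroring test_managed. It assumes a bitsandbytes build that contains this patch (so lib exposes cget_managed_ptr, cprefetch, and the cfill/c_mul wrappers) and a CUDA device with index 0.

    import torch
    import bitsandbytes.functional as F

    # get_paged allocates CUDA unified (managed) memory, so pages can migrate
    # between host and device on demand; fill/_mul dispatch the new kernels.
    A = F.get_paged(1024, 1024, dtype=torch.float32)
    B = F.get_paged(1024, 1024, dtype=torch.float32)

    F.fill(A, 17.0)     # FILL kernel; synchronizes afterwards because A is paged
    F.fill(B, 2.0)
    F._mul(A, B)        # elementwise A[i] = A[i]*B[i] via the _MUL kernel
    assert (A == 34.0).sum().item() == A.numel()

    # Optional hint to migrate the pages back to host memory.
    F.prefetch_tensor(A, to_cpu=True)
    torch.cuda.synchronize()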