From 19a7adca7a6c9bf7061a384d7e9d9b13676a1a88 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 11 Sep 2022 11:55:09 -0700 Subject: [PATCH] Fixed 2^31 max size issue for cpu blockwise quant. --- bitsandbytes/functional.py | 90 ++++++-------------------------------- csrc/common.cpp | 8 ++-- csrc/common.h | 10 +++-- csrc/cpu_ops.cpp | 85 ++++++++++++++++++++--------------- csrc/cpu_ops.h | 7 +-- csrc/pythonInterface.c | 4 +- tests/test_functional.py | 27 +++++++++++- 7 files changed, 105 insertions(+), 126 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 22200f2..c104ebd 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -369,13 +369,7 @@ def estimate_quantiles( return out -def quantize_blockwise( - A: Tensor, - code: Tensor = None, - absmax: Tensor = None, - rand=None, - out: Tensor = None, -) -> Tensor: +def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor: """ Quantize tensor A in blocks of size 4096 values. 
@@ -412,9 +406,9 @@ def quantize_blockwise( if absmax is None: n = A.numel() - num_blocks = 4096 - blocks = n // num_blocks - blocks += 1 if n % num_blocks > 0 else 0 + blocksize = (blocksize if A.device.type == 'cpu' else 4096) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) if out is None: @@ -426,46 +420,18 @@ def quantize_blockwise( assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) if A.dtype == torch.float32: - lib.cquantize_blockwise_stochastic_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - get_ptr(rand), - ct.c_int32(rand_offset), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cquantize_blockwise_stochastic_fp16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - get_ptr(rand), - ct.c_int32(rand_offset), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) else: raise ValueError( f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" ) else: if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) else: raise ValueError( f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" @@ -473,13 +439,7 @@ def quantize_blockwise( 
else: # cpu assert rand is None - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) return out, (absmax, code) @@ -529,43 +489,21 @@ def dequantize_blockwise( if quant_state is None: quant_state = (absmax, code) - if blocksize not in [2048, 4096]: - raise ValueError( - f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]" - ) if A.device.type != 'cpu': + if blocksize not in [2048, 4096]: + raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]") is_on_gpu([A, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32( - get_ptr(quant_state[1]), - get_ptr(A), - get_ptr(quant_state[0]), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16( - get_ptr(quant_state[1]), - get_ptr(A), - get_ptr(quant_state[0]), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) else: raise ValueError( f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" ) else: - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(quant_state[1]), - get_ptr(A), - get_ptr(quant_state[0]), - get_ptr(out), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) return out diff --git a/csrc/common.cpp b/csrc/common.cpp index 972602b..52f0299 100644 --- a/csrc/common.cpp +++ 
b/csrc/common.cpp @@ -12,16 +12,16 @@ void *quantize_block(void *arguments) { // 1. find absmax in block float absmax_block = -FLT_MAX; - for (int i = args->block_idx; i < args->block_end; i++) + for (long long i = args->block_idx; i < args->block_end; i++) absmax_block = fmax(absmax_block, fabs(args->A[i])); - args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block; + args->absmax[args->block_idx / args->blocksize] = absmax_block; - for (int i = args->block_idx; i < args->block_end; i++) { + for (long long i = args->block_idx; i < args->block_end; i++) { // 2. divide input value by absmax to normalize into [-1.0, 1.0] // 3. do binary search to find the closest value float normed_value = args->A[i] / absmax_block; - int idx = args->bin_searcher->scalar(normed_value); + long long idx = args->bin_searcher->scalar(normed_value); // 4. check minimal distance // The binary search returns always the value to the left, which might not be the closest value diff --git a/csrc/common.h b/csrc/common.h index 2f25a58..c99034e 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -5,18 +5,20 @@ using namespace BinSearch; +#define BLOCK_SIZE 16384 + struct quantize_block_args { BinAlgo *bin_searcher; float *code; float *A; float *absmax; unsigned char *out; - int block_end; - int block_idx; - int threadidx; + long long block_end; + long long block_idx; + long long threadidx; + long long blocksize; }; -#define BLOCK_SIZE 4096 void *quantize_block(void *arguments); diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 89de52d..303e8ed 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -4,54 +4,69 @@ using namespace BinSearch; -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) { - for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { - int valid_items = n - block_idx >= BLOCK_SIZE ? 
BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - for (int i = block_idx; i < block_end; i++) - out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE]; +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + for (long long i = block_idx; i < block_end; i++) + out[i] = code[A[i]] * absmax[block_idx / blocksize]; } } -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) { +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) +{ // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below code[0] = -1.0f; - int num_blocks = n / BLOCK_SIZE; - num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1; - - pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks); - struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *)); - - for (int i = 0; i < num_blocks; i++) - args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + long long num_blocks = n / blocksize; + num_blocks += n % blocksize == 0 ? 0 : 1; const uint32 elements_code = 256; BinAlgo bin_searcher(code, elements_code); - for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { - int valid_items = n - block_idx >= BLOCK_SIZE ? 
BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; + int thread_wave_size = 256; + // we chunk the threads into waves of 256 since the max limit is + // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) + for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) + { + pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * thread_wave_size); - struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE]; - arg->bin_searcher = &bin_searcher; - arg->code = code; - arg->A = A; - arg->absmax = absmax; - arg->out = out; - arg->block_end = block_end; - arg->block_idx = block_idx; - arg->threadidx = block_idx / BLOCK_SIZE; + struct quantize_block_args **args = (quantize_block_args **) malloc(thread_wave_size * sizeof(quantize_block_args *)); + + for(long long i = 0; i < thread_wave_size; i++) + args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + + int chunks_processed = 0; + for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) + { + long long valid_items = n - block_idx >= blocksize ? 
blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + + struct quantize_block_args *arg = args[chunks_processed]; + arg->bin_searcher = &bin_searcher; + arg->code = code; + arg->A = A; + arg->absmax = absmax; + arg->out = out; + arg->block_end = block_end; + arg->block_idx = block_idx; + arg->threadidx = block_idx / blocksize; + arg->blocksize = blocksize; + + pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); + chunks_processed += 1; + if(chunks_processed == thread_wave_size){ break; } + } + + for (int i = 0; i < thread_wave_size; i++) + int err = pthread_join(threads[i], NULL); + + free(threads); + for (int i = 0; i < thread_wave_size; i++) + free(args[i]); + free(args); - pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg); } - for (int i = 0; i < num_blocks; i++) - int err = pthread_join(threads[i], NULL); - - free(threads); - for (int i = 0; i < num_blocks; i++) - free(args[i]); - free(args); } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 57145a9..2ddf81e 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -1,9 +1,10 @@ #ifndef BITSANDBYTES_CPU_OPS_H #define BITSANDBYTES_CPU_OPS_H +#include +#include -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n); - -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n); +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n); +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n); #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 0707674..58e26a9 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -287,7 +287,7 @@ extern "C" void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } #endif - void 
cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } - void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } } diff --git a/tests/test_functional.py b/tests/test_functional.py index 14cc21e..d07affe 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1815,14 +1815,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): batch_size = 1 seqdim = 1 values = [] -#values.append((batch_size, seqdim, 768, 4 * 768)) +values.append((batch_size, seqdim, 768, 4 * 768)) # values.append((batch_size, seqdim, 1024, 4*1024)) # values.append((batch_size, seqdim, 1536, 4*1536)) # values.append((batch_size, seqdim, 2048, 4*2048)) # values.append((batch_size, seqdim, 2560, 4*2560)) # values.append((batch_size, seqdim, 4096, 4*4096)) # values.append((batch_size, seqdim, 5140, 4*5140)) -values.append((batch_size, seqdim, 12288, 4*12288)) +#values.append((batch_size, seqdim, 12288, 4*12288)) names = [ "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values ] @@ -2125,3 +2125,26 @@ def test_extract_outliers(): assert outliers2.shape[1] == idx.numel() torch.testing.assert_allclose(outliers1, outliers2) + + + +def test_blockwise_cpu_large(): + diffs = [] + reldiffs = [] + batch = 128 + seq = 128 + hidden = 14336 + for blocksize in [4096, 16384]: + for i in range(2): + A1 = torch.randn(batch, seq, hidden, device='cpu') + t0 = time.time() + C, S = F.quantize_blockwise(A1, blocksize=blocksize) + A2 = 
F.dequantize_blockwise(C, S, blocksize=blocksize) + print(time.time() - t0) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diffs[-1] < 0.011 + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs))