forked from mrq/bitsandbytes-rocm

commit 3634fc738b: Merge branch 'TimDettmers:main' into memory-efficient-backward

CHANGELOG.md (16 lines changed)
@@ -90,3 +90,19 @@ Features:
 Bug fixes:
 - Now throws an error if LLM.int8() is used on a GPU that is not supported.
 - Enhances error messaging if CUDA SETUP fails.
+
+### 0.33.0
+
+#### Various bug fixes
+
+Features:
+ - CPU quantization now supports a variable `blocksize` to trade off quantization speed against precision.
+
+Bug fixes:
+ - fixed an issue in CPU quantization where tensors with more than 2^31 elements would fail 19a7adca7a6c9bf7061a384d7e9d9b13676a1a88
+ - fixed a bug where cpu binaries would fail if no GPU was detected eab4d8232d558f2e6bd7f7cc3d00e2e6e94f4e80
+ - fixed an issue where cpu binaries caused additional stdout messages 92a3363096e10ad6a5c4e944af898bd1186d806a
+ - fixed an import of bnb.utils 2e630b55f51d454f3bd723dffda68a07ef93190c
+
+We thank @mryab, @mbrukman, @chessgecko, and @dbaranchuk for pull requests with bug fixes and new features.
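For orientation before the code diffs below, here is a minimal usage sketch of the new variable blocksize. It mirrors the test added at the bottom of this diff and is illustrative, not part of the commit:

    import torch
    import bitsandbytes.functional as F

    A1 = torch.randn(128, 128, 14336, device='cpu')

    # Larger blocks quantize faster; smaller blocks preserve more precision.
    C, S = F.quantize_blockwise(A1, blocksize=16384)   # S holds (absmax, code)
    A2 = F.dequantize_blockwise(C, S, blocksize=16384)
    print(torch.abs(A1 - A2).mean())                   # mean absolute error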
bitsandbytes/functional.py

@@ -369,13 +369,7 @@ def estimate_quantiles(
     return out


-def quantize_blockwise(
-    A: Tensor,
-    code: Tensor = None,
-    absmax: Tensor = None,
-    rand=None,
-    out: Tensor = None,
-) -> Tensor:
+def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
@@ -412,9 +406,9 @@ def quantize_blockwise(

     if absmax is None:
         n = A.numel()
-        num_blocks = 4096
-        blocks = n // num_blocks
-        blocks += 1 if n % num_blocks > 0 else 0
+        blocksize = (blocksize if A.device.type == 'cpu' else 4096)
+        blocks = n // blocksize
+        blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)

     if out is None:
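For intuition, the ceiling division above sizes absmax at one scale per block; a quick worked sketch with illustrative numbers:

    n = 10_000_000                            # A.numel(), illustrative
    blocksize = 4096
    blocks = n // blocksize                   # 2441 full blocks
    blocks += 1 if n % blocksize > 0 else 0   # one partial block -> 2442
    # absmax then has shape (2442,): one absolute-maximum scale per block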
@@ -426,46 +420,18 @@ def quantize_blockwise(
         assert rand.numel() >= 1024
         rand_offset = random.randint(0, 1023)
         if A.dtype == torch.float32:
-            lib.cquantize_blockwise_stochastic_fp32(
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                get_ptr(rand),
-                ct.c_int32(rand_offset),
-                ct.c_int(A.numel()),
-            )
+            lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
         elif A.dtype == torch.float16:
-            lib.cquantize_blockwise_stochastic_fp16(
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                get_ptr(rand),
-                ct.c_int32(rand_offset),
-                ct.c_int(A.numel()),
-            )
+            lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
         else:
             raise ValueError(
                 f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
             )
     else:
         if A.dtype == torch.float32:
-            lib.cquantize_blockwise_fp32(
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(A.numel()),
-            )
+            lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
         elif A.dtype == torch.float16:
-            lib.cquantize_blockwise_fp16(
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(A.numel()),
-            )
+            lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
         else:
             raise ValueError(
                 f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
@@ -473,13 +439,7 @@ def quantize_blockwise(
     else:
         # cpu
        assert rand is None
-        lib.cquantize_blockwise_cpu_fp32(
-            get_ptr(code),
-            get_ptr(A),
-            get_ptr(absmax),
-            get_ptr(out),
-            ct.c_int(A.numel()),
-        )
+        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

     return out, (absmax, code)
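Note that the Python wrapper now passes ct.c_longlong where the C entry point takes long long. A hypothetical sketch of the matching ctypes prototype (not taken from the source; shown only to illustrate why the integer widths must agree):

    import ctypes as ct

    lib = ct.CDLL("libbitsandbytes.so")  # hypothetical library path
    # If Python passed 32-bit c_int while the C side reads 64-bit long long,
    # sizes above 2**31 - 1 would be truncated or corrupt the call.
    lib.cquantize_blockwise_cpu_fp32.argtypes = [
        ct.c_void_p, ct.c_void_p, ct.c_void_p, ct.c_void_p,  # code, A, absmax, out
        ct.c_longlong, ct.c_longlong,                        # blocksize, n
    ]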
@@ -529,43 +489,21 @@ def dequantize_blockwise(
     if quant_state is None:
         quant_state = (absmax, code)

-    if blocksize not in [2048, 4096]:
-        raise ValueError(
-            f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
-        )
-
     if A.device.type != 'cpu':
+        if blocksize not in [2048, 4096]:
+            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
+        is_on_gpu([A, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(
-                get_ptr(quant_state[1]),
-                get_ptr(A),
-                get_ptr(quant_state[0]),
-                get_ptr(out),
-                ct.c_int(blocksize),
-                ct.c_int(A.numel()),
-            )
+            lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(
-                get_ptr(quant_state[1]),
-                get_ptr(A),
-                get_ptr(quant_state[0]),
-                get_ptr(out),
-                ct.c_int(blocksize),
-                ct.c_int(A.numel()),
-            )
+            lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(
                 f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
             )
     else:
-        lib.cdequantize_blockwise_cpu_fp32(
-            get_ptr(quant_state[1]),
-            get_ptr(A),
-            get_ptr(quant_state[0]),
-            get_ptr(out),
-            ct.c_int(A.numel()),
-        )
+        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

     return out
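For intuition, a pure-Python sketch of what blockwise dequantization computes; it mirrors the C loop in cpu_ops.cpp further down and is illustrative, not the shipped implementation:

    import torch

    def dequantize_blockwise_ref(code, A, absmax, blocksize):
        # code: 256-entry float codebook, A: uint8 codes, absmax: one scale per block
        flat = A.flatten()
        out = torch.empty(flat.numel(), dtype=torch.float32)
        for i in range(flat.numel()):
            out[i] = code[int(flat[i])] * absmax[i // blocksize]
        return out.reshape(A.shape)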
csrc/common.cpp

@@ -12,16 +12,16 @@ void *quantize_block(void *arguments) {

     // 1. find absmax in block
     float absmax_block = -FLT_MAX;
-    for (int i = args->block_idx; i < args->block_end; i++)
+    for (long long i = args->block_idx; i < args->block_end; i++)
         absmax_block = fmax(absmax_block, fabs(args->A[i]));

-    args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block;
+    args->absmax[args->block_idx / args->blocksize] = absmax_block;

-    for (int i = args->block_idx; i < args->block_end; i++) {
+    for (long long i = args->block_idx; i < args->block_end; i++) {
         // 2. divide input value by absmax to normalize into [-1.0, 1.0]
         // 3. do binary search to find the closest value
         float normed_value = args->A[i] / absmax_block;
-        int idx = args->bin_searcher->scalar(normed_value);
+        long long idx = args->bin_searcher->scalar(normed_value);

         // 4. check minimal distance
         // The binary search always returns the value to the left, which might not be the closest value
csrc/common.h

@@ -5,18 +5,20 @@

 using namespace BinSearch;

-#define BLOCK_SIZE 16384

 struct quantize_block_args {
     BinAlgo<Scalar, float, Direct2> *bin_searcher;
     float *code;
     float *A;
     float *absmax;
     unsigned char *out;
-    int block_end;
-    int block_idx;
-    int threadidx;
+    long long block_end;
+    long long block_idx;
+    long long threadidx;
+    long long blocksize;
 };

+#define BLOCK_SIZE 4096
+
 void *quantize_block(void *arguments);
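Background for the int to long long change above: a signed 32-bit int tops out at 2^31 - 1 = 2,147,483,647, so block indices into any tensor with more elements than that (roughly 8 GiB of float32) would overflow. Widening block_idx, block_end, and threadidx, and carrying blocksize in the struct instead of relying on a compile-time BLOCK_SIZE, removes both limits at once; this is the changelog's "tensors with more than 2^31 elements would fail" fix.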
csrc/cpu_ops.cpp

@@ -4,37 +4,47 @@

 using namespace BinSearch;

-void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) {
-    for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
-        int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
-        int block_end = block_idx + valid_items;
-        for (int i = block_idx; i < block_end; i++)
-            out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE];
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) {
+    for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
+        long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
+        long long block_end = block_idx + valid_items;
+        for (long long i = block_idx; i < block_end; i++)
+            out[i] = code[A[i]] * absmax[block_idx / blocksize];
     }
 }

-void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) {
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
+{

     // the default code has range [-0.993, 1.0], which can cause an error in the binary search algorithm used below
     code[0] = -1.0f;

-    int num_blocks = n / BLOCK_SIZE;
-    num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
-
-    pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks);
-    struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *));
-
-    for (int i = 0; i < num_blocks; i++)
-        args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
+    long long num_blocks = n / blocksize;
+    num_blocks += n % blocksize == 0 ? 0 : 1;

     const uint32 elements_code = 256;
     BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);

-    for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
-        int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
-        int block_end = block_idx + valid_items;
-
-        struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE];
+    int thread_wave_size = 256;
+    // we chunk the threads into waves of 256 since the max thread limit is
+    // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
+    for (long long offset = 0; offset < num_blocks; offset += thread_wave_size)
+    {
+        long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
+        pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
+        struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
+
+        for (long long i = 0; i < valid_chunks; i++)
+            args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
+
+        int chunks_processed = 0;
+        for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize)
+        {
+            long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
+            long long block_end = block_idx + valid_items;
+
+            struct quantize_block_args *arg = args[chunks_processed];
             arg->bin_searcher = &bin_searcher;
             arg->code = code;
             arg->A = A;
@@ -42,16 +52,22 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int
             arg->out = out;
             arg->block_end = block_end;
             arg->block_idx = block_idx;
-            arg->threadidx = block_idx / BLOCK_SIZE;
+            arg->threadidx = block_idx / blocksize;
+            arg->blocksize = blocksize;

-            pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg);
+            pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
+            chunks_processed += 1;
+            if (chunks_processed == valid_chunks) { break; }
+        }

-        for (int i = 0; i < num_blocks; i++)
+        for (int i = 0; i < valid_chunks; i++)
             int err = pthread_join(threads[i], NULL);

         free(threads);
-        for (int i = 0; i < num_blocks; i++)
+        for (int i = 0; i < valid_chunks; i++)
             free(args[i]);
         free(args);
+    }
 }
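The wave pattern above bounds the number of live threads. A small Python sketch of the same idea (hypothetical, for illustration only; the commit itself uses pthreads in C++):

    import threading

    def run_in_waves(blocks, work, wave_size=256):
        # Launch at most wave_size threads, join them all, then start the
        # next wave, so the OS per-process thread limit is never exceeded.
        for offset in range(0, len(blocks), wave_size):
            wave = [threading.Thread(target=work, args=(b,))
                    for b in blocks[offset:offset + wave_size]]
            for t in wave:
                t.start()
            for t in wave:
                t.join()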
csrc/cpu_ops.h

@@ -1,9 +1,10 @@
 #ifndef BITSANDBYTES_CPU_OPS_H
 #define BITSANDBYTES_CPU_OPS_H

 #include <iostream>
 #include <stdio.h>

-void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
-
-void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n);
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n);

 #endif
csrc/pythonInterface.c

@@ -287,7 +287,7 @@ extern "C"
     void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }

 #endif
-    void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
-    void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
+    void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
+    void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
 }
setup.py (2 lines changed)

@@ -18,7 +18,7 @@ def read(fname):

 setup(
     name=f"bitsandbytes",
-    version=f"0.32.3",
+    version=f"0.33.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
tests/test_functional.py

@@ -1815,14 +1815,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
 batch_size = 1
 seqdim = 1
 values = []
-#values.append((batch_size, seqdim, 768, 4 * 768))
+values.append((batch_size, seqdim, 768, 4 * 768))
 # values.append((batch_size, seqdim, 1024, 4*1024))
 # values.append((batch_size, seqdim, 1536, 4*1536))
 # values.append((batch_size, seqdim, 2048, 4*2048))
 # values.append((batch_size, seqdim, 2560, 4*2560))
 # values.append((batch_size, seqdim, 4096, 4*4096))
 # values.append((batch_size, seqdim, 5140, 4*5140))
-values.append((batch_size, seqdim, 12288, 4*12288))
+#values.append((batch_size, seqdim, 12288, 4*12288))
 names = [
     "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
 ]

@@ -2125,3 +2125,26 @@ def test_extract_outliers():
     assert outliers2.shape[1] == idx.numel()

     torch.testing.assert_allclose(outliers1, outliers2)
+
+
+def test_blockwise_cpu_large():
+    diffs = []
+    reldiffs = []
+    batch = 128
+    seq = 128
+    for hidden in [128, 14336]:
+        for blocksize in [4096, 16384]:
+            for i in range(2):
+                A1 = torch.randn(batch, seq, hidden, device='cpu')
+                t0 = time.time()
+                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
+                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
+                print(time.time() - t0)
+                diff = torch.abs(A1 - A2)
+                reldiff = diff / torch.abs(A1 + 1e-8)
+                diffs.append(diff.mean().item())
+                reldiffs.append(reldiff.mean().item())
+                assert diffs[-1] < 0.011
+                # print(sum(diffs)/len(diffs))
+                # print(sum(reldiffs)/len(reldiffs))
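The new CPU test can be run on its own with, e.g., pytest -k blockwise_cpu_large (the test file path, tests/test_functional.py, is inferred from the hunk context above).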