From 0f9d30207f7a86c6be17f8fd897f0716db32cdfd Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Wed, 19 Apr 2023 11:48:47 -0700
Subject: [PATCH] Added nested quantization for blockwise quantization.

---
 bitsandbytes/functional.py | 25 +++++++++----
 tests/test_functional.py   | 72 ++++++++++++++++++++------------------
 2 files changed, 55 insertions(+), 42 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ff0eb7e..eb49800 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -541,7 +541,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out
 
 
-def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
+def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
 
@@ -586,7 +586,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         out = torch.zeros_like(A, dtype=torch.uint8)
 
     if A.device.type != 'cpu':
-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
+        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
         cblocksize = ct.c_int32(blocksize)
         prev_device = pre_call(A.device)
         code = code.to(A.device)
@@ -616,7 +616,15 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
-    state = [absmax, code, blocksize]
+    if nested:
+        offset = absmax.mean()
+        absmax -= offset
+        qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
+        state = [qabsmax, code, blocksize, nested, offset, state2]
+    else:
+        state = [absmax, code, blocksize, nested, None, None]
+
+
 
     return out, state
 
@@ -628,6 +636,7 @@ def dequantize_blockwise(
     code: Tensor = None,
     out: Tensor = None,
     blocksize: int = 4096,
+    nested=False
 ) -> Tensor:
     """
     Dequantizes blockwise quantized values.
@@ -665,13 +674,15 @@ def dequantize_blockwise(
     if quant_state is None:
         quant_state = (absmax, code, blocksize)
     else:
-        absmax, code, blocksize = quant_state
-
+        absmax, code, blocksize, nested, offset, state2 = quant_state
+        if nested:
+            absmax = dequantize_blockwise(absmax, state2)
+            absmax += offset
 
     if A.device.type != 'cpu':
         device = pre_call(A.device)
         code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64, 32]:
+        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
             raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, absmax, out])
         if out.dtype == torch.float32:
@@ -736,7 +747,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     if out is None:
         out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
 
-    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
 
     prev_device = pre_call(A.device)
     is_on_gpu([A, out, absmax])
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 61ea712..82f6a71 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -150,42 +150,44 @@ def test_dynamic_quantization():
     assert diff < 0.004
 
 
-def test_dynamic_blockwise_quantization():
-    #print('')
-    for blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]:
-        diffs = []
-        reldiffs = []
-        for i in range(100):
-            A1 = torch.randn(1024, 1024, device="cuda")
-            C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-            A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-            diff = torch.abs(A1 - A2)
-            reldiff = diff / torch.abs(A1 + 1e-8)
-            diffs.append(diff.mean().item())
-            reldiffs.append(reldiff.mean().item())
-        abserr = sum(diffs)/len(diffs)
-        relerr = sum(reldiffs)/len(reldiffs)
-        assert abserr < 0.011
-        assert relerr < 0.018
-        #print('randn', blocksize, sum(diffs)/len(diffs))
-        #print('randn', blocksize, sum(reldiffs)/len(reldiffs))
-        diffs = []
-        for i in range(100):
-            A1 = torch.rand(1024, 1024, device="cuda")
-            C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-            A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-            diff = torch.abs(A1 - A2)
-            reldiff = diff / torch.abs(A1 + 1e-8)
-            diffs.append(diff.mean().item())
-            reldiffs.append(reldiff.mean().item())
-        #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
-        abserr = sum(diffs)/len(diffs)
-        relerr = sum(reldiffs)/len(reldiffs)
-        assert abserr < 0.0035
-        assert relerr < 0.015
-        #print('rand', blocksize, sum(diffs)/len(diffs))
-        #print('rand', blocksize, sum(reldiffs)/len(reldiffs))
+@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
+@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
+def test_dynamic_blockwise_quantization(nested, blocksize):
+    #print('')
+    diffs = []
+    reldiffs = []
+    for i in range(100):
+        A1 = torch.randn(1024, 1024, device="cuda")
+        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+        A2 = F.dequantize_blockwise(C, S)
+        diff = torch.abs(A1 - A2)
+        reldiff = diff / torch.abs(A1 + 1e-8)
+        diffs.append(diff.mean().item())
+        reldiffs.append(reldiff.mean().item())
+    abserr = sum(diffs)/len(diffs)
+    relerr = sum(reldiffs)/len(reldiffs)
+    assert abserr < 0.011
+    assert relerr < 0.018
+    print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
+    print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
+
+    diffs = []
+    for i in range(100):
+        A1 = torch.rand(1024, 1024, device="cuda")
+        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+        A2 = F.dequantize_blockwise(C, S)
+        diff = torch.abs(A1 - A2)
+        reldiff = diff / torch.abs(A1 + 1e-8)
+        diffs.append(diff.mean().item())
+        reldiffs.append(reldiff.mean().item())
+    #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
+    abserr = sum(diffs)/len(diffs)
+    relerr = sum(reldiffs)/len(reldiffs)
+    assert abserr < 0.0035
+    assert relerr < 0.015
+    print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
+    print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
 
 
 def test_dynamic_blockwise_stochastic_quantization():
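
Below is a minimal usage sketch of the nested flag this patch introduces, mirroring the parametrized test above. It assumes a CUDA build of bitsandbytes with bitsandbytes.functional imported as F, as the tests do; the tensor shape, the blocksize of 256, and the quoted error magnitude are illustrative only, not part of the patch.

    import torch
    import bitsandbytes.functional as F

    # Any float32 CUDA tensor works; 1024x1024 mirrors the test above.
    A1 = torch.randn(1024, 1024, device="cuda")

    # nested=True blockwise-quantizes the per-block absmax statistics a second
    # time; the quant state S then holds [qabsmax, code, blocksize, nested,
    # offset, state2] instead of the plain float32 absmax tensor.
    C, S = F.quantize_blockwise(A1, blocksize=256, nested=True)

    # S records the nested flag, the offset, and the second-level state, so
    # dequantization needs no extra arguments.
    A2 = F.dequantize_blockwise(C, S)

    # Mean absolute error; on the order of 1e-2 for randn input per the test.
    print((A1 - A2).abs().mean().item())

The second-level quantization reduces the memory overhead of storing a float32 absmax per block; the stored offset (the mean of the absmax values) is subtracted before they are quantized and added back on dequantization, presumably so the second-level dynamic quantization sees values centered around zero.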