From 98cbc4bc4f15f5c094cd8575ddb0380a19516099 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 11:59:37 -0800 Subject: [PATCH] Added k-bit fp8 map. --- bitsandbytes/functional.py | 16 +++++--- tests/test_functional.py | 76 ++++++++++++++++++-------------------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 65eccf2..ff48b7f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8): return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) -def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits - assert e+p == 7 + has_sign = 1 if signed else 0 + assert e+p == total_bits-has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)): + for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): evalues.append(2**val) @@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): value += pval*(2**-(i+1)) pvalues.append(value) - assert len(evalues)*len(pvalues) == 128 + assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign) values = [] for ev in evalues: for pv in pvalues: - values.append(-ev*pv) + if signed: + values.append(-ev*pv) values.append(ev*pv) + if total_bits < 8: + gap = 256 - len(values) + for i in range(gap): + values.append(0) values.sort() code = torch.Tensor(values) code /= code.max() diff --git a/tests/test_functional.py b/tests/test_functional.py index 494bf51..bd4dafe 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -11,7 +11,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F torch.set_printoptions( - precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 + precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 ) k = 20 @@ -2095,49 +2095,43 @@ def test_fp8_quant(): def test_few_bit_quant(): for bits in range(2, 9): - code = F.create_linear_map(True, bits=bits).cuda() - assert code.numel() == 256 - print(bits) - for i in range(100): + for method in ['linear', 'fp8']: + code = None + if method == 'linear': + code = F.create_linear_map(True, bits=bits).cuda() + elif method == 'fp8': + ebits = math.ceil(bits/2) + pbits = bits-ebits-1 + code = F.create_fp8_map(True, ebits, pbits, bits).cuda() + print(ebits, pbits, bits) + print(code) + assert code.numel() == 256 + print(bits) + for i in range(10): - values = torch.randn(1, 24, device='cuda') - values /= values.abs().max() - #values[values.abs() < 1e-6] += 1e-5 + values = torch.randn(1, 32, device='cuda') + values /= values.abs().max() + #values[values.abs() < 1e-6] += 1e-5 - q1 = [] - v1 = [] - for v in values[0]: - idx = torch.abs(v-code).argmin() - q1.append(idx.item()) - v1.append(code[idx].item()) + q1 = [] + v1 = [] + for v in values[0]: + idx = torch.abs(v-code).argmin() + q1.append(idx.item()) + v1.append(code[idx].item()) - q1 = torch.Tensor(q1).cuda() - v1 = torch.Tensor(v1).cuda() + q1 = torch.Tensor(q1).cuda() + v1 = torch.Tensor(v1).cuda() - q2, S2 = F.quantize(values, code=code) - v2 = F.dequantize(q2, S2) + q2, S2 = F.quantize(values, code=code) + v2 = F.dequantize(q2, S2) - idx = torch.isclose(q1.int(), q2.int()) - if idx.sum(): - # some weird cases - err1 = torch.abs(v1-values).mean() - err2 = torch.abs(v2-values).mean() - assert err2 <= err1 + idx = torch.isclose(q1.int(), q2.int()) + if idx.sum(): + # some weird cases + err1 = torch.abs(v1-values).mean() + err2 = torch.abs(v2-values).mean() + assert err2 <= err1 - else: - torch.testing.assert_allclose(q1, q2) - - #print(e_bits, p_bits) - #abserr = [] - #relerr = [] - #for i in range(100): - # A1 = torch.randn(1024, 1024, device="cuda") - # C, SC = F.quantize_blockwise(A1, code=code) - # A2 = F.dequantize_blockwise(C, SC) - # diff = torch.abs(A1 - A2) - # reldiff = diff/torch.abs(A1+1e-8) - # abserr.append(diff.mean().item()) - # relerr.append(reldiff.mean().item()) - # #assert diff < 0.0075 - #print(sum(abserr)/len(abserr)) - #print(sum(relerr)/len(relerr)) + else: + torch.testing.assert_allclose(q1, q2)