diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c104ebd..d7e186f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -6,6 +6,7 @@ import ctypes as ct import operator import random import torch +import itertools from typing import Tuple from torch import Tensor @@ -136,6 +137,39 @@ def create_linear_map(signed=True): return torch.linspace(0.0, 1.0, 256) +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): + e = exponent_bits + p = precision_bits + assert e+p == 7 + # the exponent is biased to 2^(e-1) -1 == 0 + evalues = [] + pvalues = [] + for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)): + evalues.append(2**val) + + + lst = list(itertools.product([0, 1], repeat=precision_bits)) + for bit_pattern in lst: + value = 1 + for i, pval in enumerate(list(bit_pattern)): + value += pval*(2**-(i+1)) + pvalues.append(value) + + assert len(evalues)*len(pvalues) == 128 + values = [] + for ev in evalues: + for pv in pvalues: + values.append(-ev*pv) + values.append(ev*pv) + values.sort() + code = torch.Tensor(values) + code /= code.max() + code[127] = 0 + + return code + + + def create_dynamic_map(signed=True, n=7): """ Creates the dynamic quantiztion map. diff --git a/tests/test_functional.py b/tests/test_functional.py index cf26714..329b270 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large(): assert diffs[-1] < 0.011 # print(sum(diffs)/len(diffs)) # print(sum(reldiffs)/len(reldiffs)) + + + +def test_fp8_quant(): + for e_bits in range(1, 7): + p_bits = 7-e_bits + code = F.create_fp8_map(True, e_bits, p_bits).cuda() + + print(e_bits, p_bits) + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(sum(abserr)/len(abserr)) + print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(sum(abserr)/len(abserr)) + print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(3, sum(abserr)/len(abserr)) + print(3, sum(relerr)/len(relerr)) +