diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d7e186f..65eccf2 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -130,11 +130,17 @@ class Cusparse_Context(object): return cls._instance -def create_linear_map(signed=True): - if signed: - return torch.linspace(-1.0, 1.0, 256) +def create_linear_map(signed=True, bits=8): + sign = (-1.0 if signed else 0.0) + + values = torch.linspace(sign, 1.0, 2**bits) + gap = 256 - values.numel() + if gap == 0: + return values else: - return torch.linspace(0.0, 1.0, 256) + l = values.numel()//2 + #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) + return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): diff --git a/tests/test_functional.py b/tests/test_functional.py index 329b270..494bf51 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2091,3 +2091,53 @@ def test_fp8_quant(): print(3, sum(abserr)/len(abserr)) print(3, sum(relerr)/len(relerr)) + +def test_few_bit_quant(): + + for bits in range(2, 9): + code = F.create_linear_map(True, bits=bits).cuda() + assert code.numel() == 256 + print(bits) + for i in range(100): + + values = torch.randn(1, 24, device='cuda') + values /= values.abs().max() + #values[values.abs() < 1e-6] += 1e-5 + + q1 = [] + v1 = [] + for v in values[0]: + idx = torch.abs(v-code).argmin() + q1.append(idx.item()) + v1.append(code[idx].item()) + + q1 = torch.Tensor(q1).cuda() + v1 = torch.Tensor(v1).cuda() + + q2, S2 = F.quantize(values, code=code) + v2 = F.dequantize(q2, S2) + + idx = torch.isclose(q1.int(), q2.int()) + if idx.sum(): + # some weird cases + err1 = torch.abs(v1-values).mean() + err2 = torch.abs(v2-values).mean() + assert err2 <= err1 + + else: + torch.testing.assert_allclose(q1, q2) + + #print(e_bits, p_bits) + #abserr = [] + #relerr = [] + #for i in range(100): + # A1 = torch.randn(1024, 1024, device="cuda") + # C, SC = F.quantize_blockwise(A1, code=code) + # A2 = F.dequantize_blockwise(C, SC) + # diff = torch.abs(A1 - A2) + # reldiff = diff/torch.abs(A1+1e-8) + # abserr.append(diff.mean().item()) + # relerr.append(reldiff.mean().item()) + # #assert diff < 0.0075 + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr))