forked from mrq/bitsandbytes-rocm

Added FP8 quantization map.

This commit is contained in:
parent 8d87c0b852
commit 1efb87d89d
@@ -6,6 +6,7 @@ import ctypes as ct
 import operator
 import random
 import torch
+import itertools

 from typing import Tuple
 from torch import Tensor
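The only new import in this hunk is itertools, which the create_fp8_map function added below uses to enumerate every mantissa bit pattern. A minimal illustration of that idiom, runnable on its own:

import itertools

# All 2-bit patterns, as used with precision_bits=2:
# (0, 0), (0, 1), (1, 0), (1, 1)
for bit_pattern in itertools.product([0, 1], repeat=2):
    print(bit_pattern)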
@@ -136,6 +137,39 @@ def create_linear_map(signed=True):
     return torch.linspace(0.0, 1.0, 256)


+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+    e = exponent_bits
+    p = precision_bits
+    assert e+p == 7
+    # the exponent is biased to 2^(e-1) - 1 == 0
+    evalues = []
+    pvalues = []
+    for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
+        evalues.append(2**val)
+
+    lst = list(itertools.product([0, 1], repeat=precision_bits))
+    for bit_pattern in lst:
+        value = 1
+        for i, pval in enumerate(list(bit_pattern)):
+            value += pval*(2**-(i+1))
+        pvalues.append(value)
+
+    assert len(evalues)*len(pvalues) == 128
+    values = []
+    for ev in evalues:
+        for pv in pvalues:
+            values.append(-ev*pv)
+            values.append(ev*pv)
+    values.sort()
+    code = torch.Tensor(values)
+    code /= code.max()
+    code[127] = 0
+
+    return code
+
+
 def create_dynamic_map(signed=True, n=7):
     """
     Creates the dynamic quantization map.
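The added create_fp8_map builds a 256-entry lookup table for a signed FP8 format with e exponent bits and p mantissa bits (e + p == 7, with the remaining bit implied for the sign). Exponents are biased so the stored range is [-2^(e-1), 2^(e-1) - 1], and each mantissa bit pattern contributes a factor 1 + sum_i b_i * 2^-(i+1). With the default e=5, p=2 that gives 32 exponent values times the mantissa factors {1.0, 1.25, 1.5, 1.75}, i.e. 128 magnitudes, mirrored into negative and positive entries, normalized so the largest magnitude is 1.0, and with entry 127 pinned to exactly zero. A minimal usage sketch (building the map needs no GPU; bitsandbytes.functional is the module the tests below import as F):

import torch
import bitsandbytes.functional as F

# Default map: 5 exponent bits, 2 mantissa bits (E5M2-like).
code = F.create_fp8_map(signed=True, exponent_bits=5, precision_bits=2)

print(code.shape)         # torch.Size([256])
print(code.min().item())  # -1.0: most negative entry after normalization
print(code.max().item())  # 1.0: the map is scaled by its largest magnitude
print(code[127].item())   # 0.0: pinned so zero is exactly representable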
@@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large():
     assert diffs[-1] < 0.011
     # print(sum(diffs)/len(diffs))
     # print(sum(reldiffs)/len(reldiffs))
+
+
+def test_fp8_quant():
+    for e_bits in range(1, 7):
+        p_bits = 7-e_bits
+        code = F.create_fp8_map(True, e_bits, p_bits).cuda()
+
+        print(e_bits, p_bits)
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.randn(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1, code=code)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff/torch.abs(A1+1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            #assert diff < 0.0075
+        print(sum(abserr)/len(abserr))
+        print(sum(relerr)/len(relerr))
+
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.rand(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1, code=code)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff/torch.abs(A1+1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            #assert diff < 0.0075
+        print(sum(abserr)/len(abserr))
+        print(sum(relerr)/len(relerr))
+
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.randn(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff/torch.abs(A1+1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            #assert diff < 0.0075
+        print(3, sum(abserr)/len(abserr))
+        print(3, sum(relerr)/len(relerr))
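As committed, the test only prints mean absolute and relative round-trip errors for each exponent/mantissa split; the per-element assertion is left commented out, so it passes unless quantize/dequantize raises, and it requires a CUDA device. A minimal sketch of one such round-trip outside the test harness, using only calls that appear in this diff:

import torch
import bitsandbytes.functional as F

# Quantize and dequantize one tensor with the E5M2-style map;
# requires a CUDA device, like the test itself.
code = F.create_fp8_map(True, 5, 2).cuda()
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)

diff = (A1 - A2).abs()
print(diff.mean().item())  # mean absolute round-trip error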