Added k-bit fp8 map.
This commit is contained in:
parent
caf1832526
commit
98cbc4bc4f
@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
     return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())


-def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
     e = exponent_bits
     p = precision_bits
-    assert e+p == 7
+    has_sign = 1 if signed else 0
+    assert e+p == total_bits-has_sign
     # the exponent is biased to 2^(e-1) -1 == 0
     evalues = []
     pvalues = []
-    for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
+    for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
         evalues.append(2**val)

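Both hunks in this file patch create_fp8_map, which sits next to create_linear_map in the functional module the tests import as F. The key change is the new has_sign term: when the map is signed, one bit of the k-bit budget goes to the sign, so the exponent range shrinks by one bit. A small standalone walk-through (mine, not part of the commit) of what the patched exponent loop yields for a signed 4-bit map with exponent_bits=2:

    # signed 4-bit case: has_sign=1, so the loop runs over range(-2, 2)
    exponent_bits, has_sign = 2, 1
    evalues = [2**val for val in range(-(2**(exponent_bits - has_sign)),
                                       2**(exponent_bits - has_sign))]
    print(evalues)  # [0.25, 0.5, 1, 2] -- each exponent becomes a scale 2**val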
@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
             value += pval*(2**-(i+1))
         pvalues.append(value)

-    assert len(evalues)*len(pvalues) == 128
+    assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
     values = []
     for ev in evalues:
         for pv in pvalues:
-            values.append(-ev*pv)
+            if signed:
+                values.append(-ev*pv)
             values.append(ev*pv)
+    if total_bits < 8:
+        gap = 256 - len(values)
+        for i in range(gap):
+            values.append(0)
     values.sort()
     code = torch.Tensor(values)
     code /= code.max()
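Reading the two hunks together: the codebook is the cross product of exponent scales and fraction values, mirrored through zero when the map is signed, then zero-padded to 256 entries so sub-8-bit maps still fit the fixed-size quantization tables. Below is a minimal self-contained sketch of that construction; it is my reading of the patched function, not the commit's code. The pvalues part reconstructs the fraction loop whose first line appears as context above, assuming an implicit leading 1 bit, and the sketch assumes a signed map:

    import itertools
    import torch

    def fp8_map_sketch(signed=True, exponent_bits=2, precision_bits=1, total_bits=4):
        # sign bit + exponent bits + fraction bits must fill the code width
        has_sign = 1 if signed else 0
        assert exponent_bits + precision_bits == total_bits - has_sign
        evalues = [2**v for v in range(-(2**(exponent_bits - has_sign)),
                                       2**(exponent_bits - has_sign))]
        # each fraction bit pattern encodes an implicit 1 plus sum(b_i * 2**-(i+1))
        pvalues = [1 + sum(b * 2**-(i+1) for i, b in enumerate(pattern))
                   for pattern in itertools.product([0, 1], repeat=precision_bits)]
        values = []
        for ev in evalues:
            for pv in pvalues:
                if signed:
                    values.append(-ev * pv)
                values.append(ev * pv)
        values += [0] * (256 - len(values))  # zero-pad sub-8-bit maps to 256 entries
        values.sort()
        code = torch.Tensor(values)
        return code / code.max()             # normalize the map into [-1, 1]

    code = fp8_map_sketch()  # 4-bit map: 16 signed fp8 values plus 240 zeros
    print(code.unique())

The zero padding is what lets the assert code.numel() == 256 in the test below hold for every bit width.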
@@ -11,7 +11,7 @@ import bitsandbytes as bnb
 from bitsandbytes import functional as F

 torch.set_printoptions(
-    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
+    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
 )
 k = 20

@@ -2095,49 +2095,43 @@ def test_fp8_quant():
 def test_few_bit_quant():

     for bits in range(2, 9):
-        code = F.create_linear_map(True, bits=bits).cuda()
-        assert code.numel() == 256
-        print(bits)
-        for i in range(100):
+        for method in ['linear', 'fp8']:
+            code = None
+            if method == 'linear':
+                code = F.create_linear_map(True, bits=bits).cuda()
+            elif method == 'fp8':
+                ebits = math.ceil(bits/2)
+                pbits = bits-ebits-1
+                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
+                print(ebits, pbits, bits)
+                print(code)
+            assert code.numel() == 256
+            print(bits)
+            for i in range(10):

-            values = torch.randn(1, 24, device='cuda')
+                values = torch.randn(1, 32, device='cuda')
                 values /= values.abs().max()
                 #values[values.abs() < 1e-6] += 1e-5

                 q1 = []
                 v1 = []
                 for v in values[0]:
                     idx = torch.abs(v-code).argmin()
                     q1.append(idx.item())
                     v1.append(code[idx].item())

                 q1 = torch.Tensor(q1).cuda()
                 v1 = torch.Tensor(v1).cuda()

                 q2, S2 = F.quantize(values, code=code)
                 v2 = F.dequantize(q2, S2)

                 idx = torch.isclose(q1.int(), q2.int())
                 if idx.sum():
                     # some weird cases
                     err1 = torch.abs(v1-values).mean()
                     err2 = torch.abs(v2-values).mean()
                     assert err2 <= err1
                 else:
                     torch.testing.assert_allclose(q1, q2)
-
-        #print(e_bits, p_bits)
-        #abserr = []
-        #relerr = []
-        #for i in range(100):
-        #    A1 = torch.randn(1024, 1024, device="cuda")
-        #    C, SC = F.quantize_blockwise(A1, code=code)
-        #    A2 = F.dequantize_blockwise(C, SC)
-        #    diff = torch.abs(A1 - A2)
-        #    reldiff = diff/torch.abs(A1+1e-8)
-        #    abserr.append(diff.mean().item())
-        #    relerr.append(reldiff.mean().item())
-        #    #assert diff < 0.0075
-        #print(sum(abserr)/len(abserr))
-        #print(sum(relerr)/len(relerr))
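The rewritten test derives an exponent/fraction split per bit width (ebits = math.ceil(bits/2), pbits = bits - ebits - 1, leaving one bit for the sign) and checks F.quantize/F.dequantize against a reference nearest-codebook-entry lookup. A standalone sketch of that reference path, runnable without a GPU or bitsandbytes; the function name nearest_code_quantize and the linear stand-in codebook are mine, not from the commit:

    import math
    import torch

    # per-width exponent/fraction split used by the test: one sign bit,
    # roughly half the remaining bits go to the exponent
    for bits in range(2, 9):
        ebits = math.ceil(bits / 2)
        pbits = bits - ebits - 1
        print(f"{bits}-bit: e={ebits}, p={pbits}")  # e.g. 8-bit -> e=4, p=3

    def nearest_code_quantize(values, code):
        # reference quantizer from the test: index of the closest codebook entry
        idx = (values.unsqueeze(-1) - code).abs().argmin(dim=-1)
        return idx, code[idx]

    code = torch.linspace(-1, 1, 256)    # stand-in for a created map
    values = torch.randn(32)
    values /= values.abs().max()         # the test normalizes inputs to [-1, 1]
    idx, dequant = nearest_code_quantize(values, code)
    print((dequant - values).abs().mean())  # mean absolute quantization error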