Added k-bit linear quantization.
This commit is contained in:
parent
1efb87d89d
commit
caf1832526
|
@ -130,11 +130,17 @@ class Cusparse_Context(object):
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
|
|
||||||
def create_linear_map(signed=True):
|
def create_linear_map(signed=True, bits=8):
|
||||||
if signed:
|
sign = (-1.0 if signed else 0.0)
|
||||||
return torch.linspace(-1.0, 1.0, 256)
|
|
||||||
|
values = torch.linspace(sign, 1.0, 2**bits)
|
||||||
|
gap = 256 - values.numel()
|
||||||
|
if gap == 0:
|
||||||
|
return values
|
||||||
else:
|
else:
|
||||||
return torch.linspace(0.0, 1.0, 256)
|
l = values.numel()//2
|
||||||
|
#return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
|
||||||
|
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
|
||||||
|
|
||||||
|
|
||||||
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
|
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
|
||||||
|
|
|
@ -2091,3 +2091,53 @@ def test_fp8_quant():
|
||||||
print(3, sum(abserr)/len(abserr))
|
print(3, sum(abserr)/len(abserr))
|
||||||
print(3, sum(relerr)/len(relerr))
|
print(3, sum(relerr)/len(relerr))
|
||||||
|
|
||||||
|
|
||||||
|
def test_few_bit_quant():
|
||||||
|
|
||||||
|
for bits in range(2, 9):
|
||||||
|
code = F.create_linear_map(True, bits=bits).cuda()
|
||||||
|
assert code.numel() == 256
|
||||||
|
print(bits)
|
||||||
|
for i in range(100):
|
||||||
|
|
||||||
|
values = torch.randn(1, 24, device='cuda')
|
||||||
|
values /= values.abs().max()
|
||||||
|
#values[values.abs() < 1e-6] += 1e-5
|
||||||
|
|
||||||
|
q1 = []
|
||||||
|
v1 = []
|
||||||
|
for v in values[0]:
|
||||||
|
idx = torch.abs(v-code).argmin()
|
||||||
|
q1.append(idx.item())
|
||||||
|
v1.append(code[idx].item())
|
||||||
|
|
||||||
|
q1 = torch.Tensor(q1).cuda()
|
||||||
|
v1 = torch.Tensor(v1).cuda()
|
||||||
|
|
||||||
|
q2, S2 = F.quantize(values, code=code)
|
||||||
|
v2 = F.dequantize(q2, S2)
|
||||||
|
|
||||||
|
idx = torch.isclose(q1.int(), q2.int())
|
||||||
|
if idx.sum():
|
||||||
|
# some weird cases
|
||||||
|
err1 = torch.abs(v1-values).mean()
|
||||||
|
err2 = torch.abs(v2-values).mean()
|
||||||
|
assert err2 <= err1
|
||||||
|
|
||||||
|
else:
|
||||||
|
torch.testing.assert_allclose(q1, q2)
|
||||||
|
|
||||||
|
#print(e_bits, p_bits)
|
||||||
|
#abserr = []
|
||||||
|
#relerr = []
|
||||||
|
#for i in range(100):
|
||||||
|
# A1 = torch.randn(1024, 1024, device="cuda")
|
||||||
|
# C, SC = F.quantize_blockwise(A1, code=code)
|
||||||
|
# A2 = F.dequantize_blockwise(C, SC)
|
||||||
|
# diff = torch.abs(A1 - A2)
|
||||||
|
# reldiff = diff/torch.abs(A1+1e-8)
|
||||||
|
# abserr.append(diff.mean().item())
|
||||||
|
# relerr.append(reldiff.mean().item())
|
||||||
|
# #assert diff < 0.0075
|
||||||
|
#print(sum(abserr)/len(abserr))
|
||||||
|
#print(sum(relerr)/len(relerr))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user