Added FP8 quantization map.

This commit is contained in:
Tim Dettmers 2022-11-03 19:49:50 -07:00
parent 8d87c0b852
commit 1efb87d89d
2 changed files with 85 additions and 0 deletions

View File

@ -6,6 +6,7 @@ import ctypes as ct
import operator
import random
import torch
import itertools
from typing import Tuple
from torch import Tensor
@ -136,6 +137,39 @@ def create_linear_map(signed=True):
return torch.linspace(0.0, 1.0, 256)
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
e = exponent_bits
p = precision_bits
assert e+p == 7
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
pvalues = []
for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
evalues.append(2**val)
lst = list(itertools.product([0, 1], repeat=precision_bits))
for bit_pattern in lst:
value = 1
for i, pval in enumerate(list(bit_pattern)):
value += pval*(2**-(i+1))
pvalues.append(value)
assert len(evalues)*len(pvalues) == 128
values = []
for ev in evalues:
for pv in pvalues:
values.append(-ev*pv)
values.append(ev*pv)
values.sort()
code = torch.Tensor(values)
code /= code.max()
code[127] = 0
return code
def create_dynamic_map(signed=True, n=7):
"""
Creates the dynamic quantiztion map.

View File

@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large():
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
def test_fp8_quant():
for e_bits in range(1, 7):
p_bits = 7-e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
print(e_bits, p_bits)
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
print(sum(abserr)/len(abserr))
print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
print(sum(abserr)/len(abserr))
print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
print(3, sum(abserr)/len(abserr))
print(3, sum(relerr)/len(relerr))