Added FP8 quantization map.

This commit is contained in:
Tim Dettmers 2022-11-03 19:49:50 -07:00
parent 8d87c0b852
commit 1efb87d89d
2 changed files with 85 additions and 0 deletions

View File

@ -6,6 +6,7 @@ import ctypes as ct
import operator import operator
import random import random
import torch import torch
import itertools
from typing import Tuple from typing import Tuple
from torch import Tensor from torch import Tensor
@ -136,6 +137,39 @@ def create_linear_map(signed=True):
return torch.linspace(0.0, 1.0, 256) return torch.linspace(0.0, 1.0, 256)
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
    """Create a 256-entry FP8 quantization codebook, normalized to [-1, 1].

    Builds every value representable with a sign bit, `exponent_bits` exponent
    bits, and `precision_bits` mantissa bits (1 + e + p = 8 total), sorts them,
    normalizes by the maximum magnitude, and pins index 127 to exact zero so a
    true zero is always representable.

    Args:
        signed: Present for API symmetry with the other map constructors;
            currently unused — the map is always symmetric around zero.
        exponent_bits: Number of exponent bits (e).
        precision_bits: Number of mantissa bits (p); e + p must equal 7.

    Returns:
        torch.Tensor of 256 sorted values in [-1, 1] with code[127] == 0.
    """
    e = exponent_bits
    p = precision_bits
    assert e + p == 7
    # The exponent is biased so that 2^(e-1) - 1 maps to 0; the raw exponent
    # range is [-2^(e-1), 2^(e-1) - 1].
    evalues = [2 ** val for val in range(-(2 ** (e - 1)), 2 ** (e - 1))]

    # Mantissa values 1.b1 b2 ... bp for every p-bit pattern.
    pvalues = []
    for bit_pattern in itertools.product([0, 1], repeat=p):
        value = 1
        for i, pval in enumerate(bit_pattern):
            value += pval * (2 ** -(i + 1))
        pvalues.append(value)

    # 2^e exponents x 2^p mantissas = 128 magnitudes; each gets both signs.
    assert len(evalues) * len(pvalues) == 128
    values = []
    for ev in evalues:
        for pv in pvalues:
            values.append(-ev * pv)
            values.append(ev * pv)
    values.sort()

    code = torch.Tensor(values)
    code /= code.max()  # normalize so the extremes are exactly -1 and 1
    code[127] = 0       # replace the smallest-magnitude negative with exact 0
    return code
def create_dynamic_map(signed=True, n=7): def create_dynamic_map(signed=True, n=7):
""" """
Creates the dynamic quantization map. Creates the dynamic quantization map.

View File

@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large():
assert diffs[-1] < 0.011 assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs)) # print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs)) # print(sum(reldiffs)/len(reldiffs))
def test_fp8_quant():
    """Smoke-test blockwise quantize/dequantize round-trips with FP8 codebooks.

    For every exponent/precision split with e + p = 7, measures the mean
    absolute and mean relative reconstruction error on standard-normal and
    uniform random inputs, plus a baseline run with the default (dynamic)
    codebook. Currently only prints the statistics; no hard accuracy
    assertion is enforced yet.
    """

    def _mean_errors(sample, code=None):
        # Average abs/rel round-trip error over 100 random 1024x1024 tensors,
        # quantizing with `code` (or the default codebook when code is None).
        abserr = []
        relerr = []
        for _ in range(100):
            A1 = sample(1024, 1024, device="cuda")
            if code is None:
                C, SC = F.quantize_blockwise(A1)
            else:
                C, SC = F.quantize_blockwise(A1, code=code)
            A2 = F.dequantize_blockwise(C, SC)
            diff = torch.abs(A1 - A2)
            reldiff = diff / torch.abs(A1 + 1e-8)
            abserr.append(diff.mean().item())
            relerr.append(reldiff.mean().item())
        # assert diff < 0.0075  # TODO: enable once error bounds are settled
        return sum(abserr) / len(abserr), sum(relerr) / len(relerr)

    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()
        print(e_bits, p_bits)

        # FP8 codebook on standard-normal inputs.
        abs_mean, rel_mean = _mean_errors(torch.randn, code=code)
        print(abs_mean)
        print(rel_mean)

        # FP8 codebook on uniform [0, 1) inputs.
        abs_mean, rel_mean = _mean_errors(torch.rand, code=code)
        print(abs_mean)
        print(rel_mean)

        # Baseline: default dynamic codebook on standard-normal inputs.
        abs_mean, rel_mean = _mean_errors(torch.randn)
        print(3, abs_mean)
        print(3, rel_mean)