Added k<256 quantile estimate.

Tim Dettmers 2022-11-06 13:05:25 -08:00
parent 98cbc4bc4f
commit 2f2063bac2
2 changed files with 74 additions and 30 deletions
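A minimal usage sketch of the headline change, assuming a CUDA build of bitsandbytes; it mirrors the call made in the updated test below rather than documenting a settled API:

import torch
from bitsandbytes import functional as F

bits = 4
values = torch.randn(2048, 2048, device='cuda')
# ask for 2**bits equally spaced quantiles instead of the previous fixed 256
q = F.estimate_quantiles(values, offset=1 / (2 * 2**bits), num_quantiles=2**bits)
print(q.numel())  # 16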

View File

@@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
-def create_dynamic_map(signed=True, n=7):
+def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     """
     Creates the dynamic quantization map.
@@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7):
     # these are additional items that come from the case
     # where all the exponent bits are zero and no
     # indicator bit is present
-    additional_items = 2 ** (7 - n) - 1
+    non_sign_bits = total_bits - (1 if signed else 0)
+    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
     if not signed:
         additional_items = 2 * additional_items
-    for i in range(n):
-        fraction_items = (
-            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
-        )
+    for i in range(max_exponent_bits):
+        fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
-        data += ((10 ** (-(n - 1) + i)) * means).tolist()
+        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
         if signed:
-            data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

     if additional_items > 0:
         boundaries = torch.linspace(0.1, 1, additional_items + 1)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
-        data += ((10 ** (-(n - 1) + i)) * means).tolist()
+        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
         if signed:
-            data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

     data.append(0)
     data.append(1.0)
+
+    gap = 256 - len(data)
+    for i in range(gap):
+        data.append(0)
+
     data.sort()
     return Tensor(data)
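For context, a hedged sketch of the padding added above: with total_bits < 8 the dynamic map has fewer than 256 real entries, and the new gap loop appends zeros until the map reaches 256 slots (the updated test below asserts exactly this). The call mirrors the test's create_dynamic_map(True, bits, bits); the printed zero count is illustrative, not a specified value.

from bitsandbytes import functional as F

bits = 4
code = F.create_dynamic_map(True, bits, bits)  # max_exponent_bits=4, total_bits=4
assert code.numel() == 256                     # smaller maps are zero-padded to 256 entries
print((code == 0).sum())                       # padding zeros plus the map's own zero entry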
@@ -371,9 +375,7 @@ def nvidia_transform(
     return out, new_state


-def estimate_quantiles(
-    A: Tensor, out: Tensor = None, offset: float = 1 / 512
-) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
     '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -393,25 +395,36 @@ def estimate_quantiles(
     out : torch.Tensor
         Tensor with the 256 estimated quantiles.
     offset : float
-        The offset for the first and last quantile from 0 and 1. Default: 1/512
+        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
+    num_quantiles : int
+        The number of equally spaced quantiles.

     Returns
     -------
     torch.Tensor:
         The 256 quantiles in float32 datatype.
     '''
+    if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
+    if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
+    if num_quantiles < 256 and offset == 1/(512):
+        # override default arguments
+        offset = 1/(2*num_quantiles)
+
     if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
     is_on_gpu([A, out])
+    device = pre_call(A.device)
     if A.dtype == torch.float32:
-        lib.cestimate_quantiles_fp32(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     elif A.dtype == torch.float16:
-        lib.cestimate_quantiles_fp16(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     else:
         raise NotImplementedError(f"Not supported data type {A.dtype}")
+    post_call(device)
+
+    if num_quantiles < 256:
+        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
+        out = out[idx]
+
     return out
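A small standalone sketch of the subsampling step above, with illustrative names only: the kernel still fills all 256 slots, and for num_quantiles < 256 the result is thinned at evenly spaced indices, endpoints included.

import torch

num_quantiles = 16
full = torch.arange(256, dtype=torch.float32)       # stand-in for the 256 kernel outputs
idx = torch.linspace(0, 255, num_quantiles).long()  # 0, 17, 34, ..., 255
subset = full[idx]                                  # 16 evenly spaced entries, first and last included
print(subset.numel())  # 16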

View File

@@ -6,9 +6,11 @@ from itertools import product
 import einops
 import pytest
 import torch
+import numpy as np

 import bitsandbytes as bnb
 from bitsandbytes import functional as F
+from scipy.stats import norm

 torch.set_printoptions(
     precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
@@ -2094,8 +2096,12 @@ def test_fp8_quant():
 def test_few_bit_quant():

+    print('')
     for bits in range(2, 9):
-        for method in ['linear', 'fp8']:
+        print('='*30, bits, '='*30)
+        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
+            abserrs = []
+            relerrs = []
             code = None
             if method == 'linear':
                 code = F.create_linear_map(True, bits=bits).cuda()
@@ -2103,10 +2109,21 @@ def test_few_bit_quant():
                 ebits = math.ceil(bits/2)
                 pbits = bits-ebits-1
                 code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
-                print(ebits, pbits, bits)
-                print(code)
+            elif method == 'dynamic':
+                code = F.create_dynamic_map(True, bits-0, bits).cuda()
+            elif method == 'quantile':
+                values = torch.randn(2048, 2048, device='cuda')
+                q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits)
+                gap = 256-q.numel()
+                q = q.tolist()
+                for i in range(gap):
+                    q.append(0)
+                q = torch.Tensor(q).cuda()
+
+                q /= q.abs().max()
+                code, idx = torch.sort(q)
+            print(method, (code==0).sum())
             assert code.numel() == 256
-            print(bits)

             for i in range(10):
                 values = torch.randn(1, 32, device='cuda')
@@ -2127,11 +2144,25 @@ def test_few_bit_quant():
                 v2 = F.dequantize(q2, S2)

                 idx = torch.isclose(q1.int(), q2.int())
+                err2 = torch.abs(v2-values)
+                abserrs.append(err2.mean().item())
+                relerrs.append((err2/(1e-10+values).abs()).mean().item())
                 if idx.sum():
                     # some weird cases
                     err1 = torch.abs(v1-values).mean()
-                    err2 = torch.abs(v2-values).mean()
-                    assert err2 <= err1
+                    assert err2.mean() <= err1
                 else:
                     torch.testing.assert_allclose(q1, q2)
+            print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
+
+
+def test_kbit_quantile_estimation():
+    for i in range(100):
+        data = torch.randn(1024, 1024, device='cuda')
+        for bits in range(2, 9):
+            p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
+            val1 = torch.Tensor(norm.ppf(p)).cuda()
+            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
+            err = torch.abs(val1-val2).mean()
+            assert err < 0.035
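For reference, a hedged standalone version of the comparison values used in the new test: norm.ppf is the inverse CDF of the standard normal distribution, and the 1.3e-4 clipping (taken from the test) keeps the outermost quantile positions away from 0 and 1, where the inverse CDF diverges.

import numpy as np
import torch
from scipy.stats import norm

bits = 4
p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)  # quantile positions clipped away from 0 and 1
theoretical = torch.Tensor(norm.ppf(p))       # theoretical quantiles of N(0, 1)
print(theoretical[:3])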