Added k<256 quantile estimate.

commit 2f2063bac2 (parent: 98cbc4bc4f)
@@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)



-def create_dynamic_map(signed=True, n=7):
+def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     """
     Creates the dynamic quantiztion map.
@@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7):
     # these are additional items that come from the case
     # where all the exponent bits are zero and no
     # indicator bit is present
-    additional_items = 2 ** (7 - n) - 1
+    non_sign_bits = total_bits - (1 if signed else 0)
+    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
     if not signed:
         additional_items = 2 * additional_items
-    for i in range(n):
-        fraction_items = (
-            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
-        )
+    for i in range(max_exponent_bits):
+        fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
-        data += ((10 ** (-(n - 1) + i)) * means).tolist()
+        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
         if signed:
-            data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

     if additional_items > 0:
         boundaries = torch.linspace(0.1, 1, additional_items + 1)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
-        data += ((10 ** (-(n - 1) + i)) * means).tolist()
+        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
         if signed:
-            data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

     data.append(0)
     data.append(1.0)
+
+    gap = 256 - len(data)
+    for i in range(gap):
+        data.append(0)
+
     data.sort()
     return Tensor(data)
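The two hunks above generalize `create_dynamic_map` from the fixed 8-bit layout (`n=7` exponent bits plus sign) to arbitrary bit widths, and zero-pad the resulting code to 256 entries so it stays compatible with the 8-bit quantization kernels. A minimal usage sketch (not part of the commit; the argument choice `max_exponent_bits=bits` mirrors the test added further below):

```python
import bitsandbytes.functional as F

for bits in range(2, 9):
    # k-bit dynamic map, zero-padded up to 256 entries by the new gap-filling loop
    code = F.create_dynamic_map(signed=True, max_exponent_bits=bits, total_bits=bits)
    assert code.numel() == 256
    print(bits, int((code == 0).sum()))  # the number of padding zeros grows as bits shrink
```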
@@ -371,9 +375,7 @@ def nvidia_transform(
     return out, new_state


-def estimate_quantiles(
-    A: Tensor, out: Tensor = None, offset: float = 1 / 512
-) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
     '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -393,25 +395,36 @@ def estimate_quantiles(
     out : torch.Tensor
         Tensor with the 256 estimated quantiles.
     offset : float
-        The offset for the first and last quantile from 0 and 1. Default: 1/512
+        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
+    num_quantiles : int
+        The number of equally spaced quantiles.

     Returns
     -------
     torch.Tensor:
         The 256 quantiles in float32 datatype.
     '''
+    if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
+    if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
+    if num_quantiles < 256 and offset == 1/(512):
+        # override default arguments
+        offset = 1/(2*num_quantiles)
+
     if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
     is_on_gpu([A, out])
+    device = pre_call(A.device)
     if A.dtype == torch.float32:
-        lib.cestimate_quantiles_fp32(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     elif A.dtype == torch.float16:
-        lib.cestimate_quantiles_fp16(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     else:
         raise NotImplementedError(f"Not supported data type {A.dtype}")
+    post_call(device)
+
+    if num_quantiles < 256:
+        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
+        out = out[idx]
+
     return out
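With `num_quantiles < 256`, the kernel call above still fills all 256 equidistant quantiles; the new trailing block then subsamples `num_quantiles` of them at evenly spaced indices. As a rough reference for what the estimator approximates, the exact k equally spaced eCDF quantiles can be computed directly in PyTorch (a sketch, not part of the commit; `reference_quantiles` is a hypothetical helper):

```python
import torch

def reference_quantiles(A: torch.Tensor, k: int) -> torch.Tensor:
    # Exact k equally spaced eCDF quantiles, using the same default
    # offset = 1/(2k) that the new code derives when num_quantiles < 256.
    offset = 1.0 / (2 * k)
    p = torch.linspace(offset, 1 - offset, k, device=A.device)
    return torch.quantile(A.float().flatten(), p)
```

Comparing `F.estimate_quantiles(A, num_quantiles=k)` against this reference gives a quick accuracy check on the approximate estimator. The remaining hunks below update the test suite accordingly.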
@@ -6,9 +6,11 @@ from itertools import product
 import einops
 import pytest
 import torch
+import numpy as np

 import bitsandbytes as bnb
 from bitsandbytes import functional as F
+from scipy.stats import norm

 torch.set_printoptions(
     precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
@@ -2094,8 +2096,12 @@ def test_fp8_quant():


 def test_few_bit_quant():

+    print('')
     for bits in range(2, 9):
-        for method in ['linear', 'fp8']:
+        print('='*30, bits, '='*30)
+        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
+            abserrs = []
+            relerrs = []
             code = None
             if method == 'linear':
                 code = F.create_linear_map(True, bits=bits).cuda()
@@ -2103,10 +2109,21 @@ def test_few_bit_quant():
                 ebits = math.ceil(bits/2)
                 pbits = bits-ebits-1
                 code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
-                print(ebits, pbits, bits)
-                print(code)
+            elif method == 'dynamic':
+                code = F.create_dynamic_map(True, bits-0, bits).cuda()
+            elif method == 'quantile':
+                values = torch.randn(2048, 2048, device='cuda')
+                q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits)
+                gap = 256-q.numel()
+                q = q.tolist()
+                for i in range(gap):
+                    q.append(0)
+                q = torch.Tensor(q).cuda()
+
+                q /= q.abs().max()
+                code, idx = torch.sort(q)
+                print(method, (code==0).sum())
             assert code.numel() == 256
-            print(bits)
             for i in range(10):

                 values = torch.randn(1, 32, device='cuda')
@@ -2127,11 +2144,25 @@ def test_few_bit_quant():
                 v2 = F.dequantize(q2, S2)

                 idx = torch.isclose(q1.int(), q2.int())
+                err2 = torch.abs(v2-values)
+                abserrs.append(err2.mean().item())
+                relerrs.append((err2/(1e-10+values).abs()).mean().item())
                 if idx.sum():
                     # some weird cases
                     err1 = torch.abs(v1-values).mean()
-                    err2 = torch.abs(v2-values).mean()
-                    assert err2 <= err1
+                    assert err2.mean() <= err1

                 else:
                     torch.testing.assert_allclose(q1, q2)
+            print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
+
+
+def test_kbit_quantile_estimation():
+    for i in range(100):
+        data = torch.randn(1024, 1024, device='cuda')
+        for bits in range(2, 9):
+            p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
+            val1 = torch.Tensor(norm.ppf(p)).cuda()
+            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
+            err = torch.abs(val1-val2).mean()
+            assert err < 0.035
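The new `test_kbit_quantile_estimation` relies on the fact that empirical quantiles of standard-normal samples converge to the theoretical quantile function `norm.ppf`. A self-contained NumPy/SciPy illustration of the identity the test checks (a sketch; the sample size and bit width here are arbitrary):

```python
import numpy as np
from scipy.stats import norm

k = 16                                    # e.g. a 4-bit codebook
data = np.random.randn(1024 * 1024)
p = np.linspace(1.3e-4, 1 - 1.3e-4, k)    # same probability grid as the test
empirical = np.quantile(data, p)          # exact eCDF quantiles of the sample
theoretical = norm.ppf(p)                 # N(0, 1) quantile function
print(np.abs(empirical - theoretical).mean())  # small; the test allows < 0.035
```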