
2528 lines
84 KiB

import math
import random
import time
from itertools import product
import einops
import pytest
import torch
import numpy as np
import bitsandbytes as bnb
from bitsandbytes import functional as F
from scipy.stats import norm
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
k = 20
def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
idx = torch.isclose(a, b, rtol, atol)
sumval = (idx == 0).sum().item()
if sumval > count:
if throw:
print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_close(a, b, rtol, atol)
return sumval
class FFN(torch.nn.Module):
def __init__(self, input_features, hidden_size, bias=True):
self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
with torch.no_grad():
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
class Timer:
def __init__(self):
self.starts = {}
self.ends = {}
self.agg = {}
def tick(self, name="default"):
if name not in self.starts:
self.starts[name] = torch.cuda.Event(enable_timing=True)
self.ends[name] = torch.cuda.Event(enable_timing=True)
ms = self.tock(name, evict=True, print_ms=False)
def tock(self, name="default", evict=True, print_ms=True):
if name in self.ends:
ms = self.starts[name].elapsed_time(self.ends[name])
if name not in self.agg:
self.agg[name] = 0.0
self.agg[name] += ms
if evict:
if print_ms and name in self.agg:
print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")
return self.agg[name]
def reset(self):
self.starts = {}
self.ends = {}
self.agg = {}
print("Resetting benchmark data")
def setup():
def teardown():
"dtype", [torch.float32, torch.float16], ids=["float", "half"]
def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device="cuda")
A =
code = F.estimate_quantiles(A)
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
A = torch.randn(1024, 1024, device="cuda")
A =
code = F.estimate_quantiles(A)
quantiles = torch.quantile(A.float(), percs)
diff = torch.abs(code - quantiles)
assert (diff > 5e-02).sum().item() == 0
def test_quantile_quantization():
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
assert diff.mean().item() < 0.0135
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float()
reldiff = diff / torch.abs(A1.float() + 1e-8)
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
assert abserr < 0.011
assert relerr < 0.018
assert A2.dtype == dtype
diffs = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float()
reldiff = diff / torch.abs(A1.float() + 1e-8)
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.0035
assert relerr < 0.015
assert A2.dtype == dtype
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
"gtype", [torch.float32, torch.float16], ids=["float", "half"]
def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda")
n = 4
step = 0
percentile = 5
for i in range(k):
step += 1
g = torch.randn(n, n, dtype=gtype, device="cuda")
gnorm1, clip2, gnorm_scale = F.percentile_clipping(
g, gnorm_vec2, step, percentile=percentile
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
gnorm2 = torch.norm(g.float())
if step == 1:
gnorm_vec1[:] = gnorm2
gnorm_vec1[step % 100] = gnorm2
vals, idx = torch.sort(gnorm_vec1)
clip1 = vals[percentile]
torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
torch.testing.assert_close(clip1, clip2)
torch.testing.assert_close(gnorm1, gnorm2)
def quant(x):
max1 = torch.abs(x).max()
x = torch.round(x / max1 * 127)
return max1,
def dequant(c, maxC):
return c.float() * (maxC / 127)
def mm_dequant(maxA, maxB, C):
return C.float() * (maxA / 127) * (maxB / 127)
def quant_multi(x, dim):
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0
x = torch.round(x / max1 * 127)
return max1,
def quant_multi_chunk(x, dim, chunk_size=32):
if dim == 1:
x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
max1 = torch.tile(max1, (1, 1, x.shape[1]))
max1 = max1.view(x.shape)
elif dim == 0:
x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
max1 = torch.tile(max1, (x.shape[0], 1, 1))
max1 = max1.view(x.shape)
max1[max1 == 0] = 1.0
x = torch.round(x / max1 * 127)
return max1,
def quant_minmax(A):
minA = A.min()
maxA = A.max()
def mean(xx):
return sum(xx) / float(len(xx))
# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
lambda x, dim: quant(x),
lambda x, dim: quant(x),
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
for vals in values_names
"dim1, dim2, quant_methods, batched", values, ids=names
def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
for i in range(5):
if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 2)
maxB, Bc = quant_methods[1](B, 1)
A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0)
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
if batched:
out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float())
out2 =, B)
C = F.igemm(Ac, Bc)
out = quant_methods[4](maxA, maxB, C)
std = out2.std()
out /= std
out2 /= std
err = torch.abs(out - out2)
relerr = err / torch.abs(out2)
def test_stable_embedding():
layer = bnb.nn.StableEmbedding(1024, 1024)
n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
for vals in values
"hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
shapeA = (
(batch_dim, hidden_dim)
if not transpose[0]
else (hidden_dim, batch_dim)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B.t())
elif transpose[0] and not transpose[1]:
out2 = torch.matmul(A.t().float(), B.float())
out = F.igemm(A.t(), B)
elif transpose[0] and transpose[1]:
out2 = torch.matmul(A.t().float(), B.t().float())
out = F.igemm(A.t(), B.t())
torch.testing.assert_close(out.float(), out2)
for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
if transpose[1]
else (hidden_dim, 32 * random.randint(1, 4))
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B.t())
torch.testing.assert_close(out.float(), out2)
n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
"seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
seq_dim = seq_dim - (seq_dim % 32)
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 2)
for i in range(25):
A = torch.randint(
-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
B = torch.randint(
-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
iout = torch.empty(
A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
out = F.igemm(A, B, out=iout)
torch.testing.assert_close(out.float(), out2)
n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
for vals in values
"seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
minA = torch.amin(x, dim=2, keepdim=True)
scale = (maxA - minA) / 2.0
return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
seq_dim = seq_dim - (seq_dim % 16)
hidden_dim = hidden_dim - (hidden_dim % 16)
batch_dim = batch_dim - (batch_dim % 2)
errs = []
relerrs = []
errs2 = []
relerrs2 = []
for i in range(k):
A = torch.normal(
0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
if transpose:
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
Ac, minA, scale = min_max(A)
if transpose:
maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
out = F.igemm(Ac, Bc.t())
out2 = torch.matmul(A, B.t())
offset = B.t().sum(0) * (minA + scale)
out = out.float()
out = (out * maxB.t() * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc.t())
out3 = mm_dequant(maxA, maxB.t(), out3)
maxB, Bc = quant_multi(B, dim=0)
offset = B.sum(0) * (minA + scale)
out = F.igemm(Ac, Bc)
out2 = torch.matmul(A, B)
out = out.float()
out = (out * maxB * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc)
out3 = mm_dequant(maxA, maxB, out3)
std = out2.std()
out2 /= std
out /= std
out3 /= std
err = torch.abs(out - out2)
relerr = err / (torch.abs(out2) + 1e-7)
err2 = torch.abs(out3 - out2)
relerr2 = err2 / (torch.abs(out2) + 1e-7)
# print(mean(errs))
# print(mean(relerrs))
# print(mean(errs2))
# print(mean(relerrs2))
assert mean(errs) < 0.015
assert mean(relerrs) < 0.3
n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
for vals in values
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.bmm(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
out = F.igemm(A, B.permute([0, 2, 1]))
elif transpose[0] and not transpose[1]:
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]:
out2 = torch.bmm(
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_close(out.float(), out2.float())
n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
for i in range(k):
A = torch.randn(size=(dim2, dim3), device="cuda")
qA, SA = F.vectorwise_quant(A, dim=0)
A1 = F.vectorwise_dequant(qA, SA)
n = A1.numel()
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002))
n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
if dims == 3 and out_order != "col32":
if dtype == torch.int32 and out_order != "col32":
func = F.get_transform_func(dtype, orderA, orderOut, transpose)
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
out, S = F.nvidia_transform(A, to_order=orderOut)
if orderOut == "row":
torch.testing.assert_close(A.flatten(), out.flatten())
elif orderOut == "col":
torch.testing.assert_close(A.t().flatten(), out.flatten())
elif orderOut == "col32":
if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3:
n = (
* A.shape[1]
* (A.shape[2] + (32 - (A.shape[2] % 32)))
assert out.numel() == n
elif orderOut == "col_turing":
# 32 col 8 row tiles
n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
A.shape[1] + (32 - (A.shape[1] % 32))
assert out.numel() == n
total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
for row in range(A.shape[0]):
for col in range(A.shape[1]):
i = row * A.shape[1]
j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0)
rowtile = (
(row // 8) + (1 if row % 8 != 0 else 0)
) * total_coltile
offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32
row2 = (row % 8) * 32
assert A.flatten()[i + j] == A[row, col]
# assert A.flatten()[i+j] == out.flatten()[row2+col2]
# torch.testing.assert_close(A.flatten()[i+j], A[row, col])
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32":
out2, S = F.nvidia_transform(
out, from_order=orderOut, to_order="row", state=S
torch.testing.assert_close(A, out2)
n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]
dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
for vals in values
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k):
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
elif dims == 3:
A = torch.randint(
-128, 127, size=(dim1, dim2, dim3), device="cuda"
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, "col32")
B2, SB = F.transform(B, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_close(C1, C3.float())
# transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
C1 = torch.matmul(A.float(), B.float())
B2t, SBt = F.transform(B, "col_turing", transpose=True)
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_close(C1, C3.float())
dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
for vals in values
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
formatB = F.get_special_format_str()
for i in range(k):
if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3:
A = torch.normal(
0, 0.5, size=(dim1, dim2, dim3), device="cuda"
B = torch.randn((dim4, dim3), device="cuda").half()
C1 = torch.matmul(A, B.t())
C2 = bnb.matmul(A, B.t())
A = A.view(-1, A.shape[-1])
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
C32A, SA = F.transform(CA, "col32")
CxB, SB = F.transform(CB, to_order=formatB)
out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)
# print('')
# print(output.flatten()[:10])
# print(C1.flatten()[:10])
# print(C2.flatten()[:10])
# torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
# transpose
# B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
# C1 = torch.matmul(A.float(), B.float())
# B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
# C2, SC = F.igemmlt(A2, B2t, SA, SBt)
# C3, S = F.transform(C2, 'row', state=SC)
# torch.testing.assert_close(C1, C3.float())
batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
(batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
(batch_size, seqdim, 5120, 3 * 5120),
(batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
# values = list(product(batch, seq, model, hidden))
names = [
"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
grad = torch.randn(batch, seq, model, device="cuda").half()
w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
# torch.cuda.synchronize()
## warmup
# for i in range(100):
# torch.matmul(A, w1.t())
# torch.cuda.synchronize()
dtype = torch.int8
A = A.view(-1, A.shape[-1]).contiguous()
grad = grad.view(-1, grad.shape[-1]).contiguous()
t0 = time.time()
for i in range(k):
out1 = torch.matmul(A, w1.t()) # fc1
# out2 = torch.matmul(out1, w2.t())# fc2
# d1 = torch.matmul(grad, w2) # delta1
# d2 = torch.matmul(d1, w1) # delta2
# grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
# grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
t16 = time.time() - t0
# torch.cuda.empty_cache()
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# C32A, SA = F.transform2(CA, 'col32')
## fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
# C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
# C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# #CTw2, Sw2 = F.transform2(Cw2, formatB)
# #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# C32A, SA = F.transform2(CA, 'col32')
# # fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
# #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
# #print(coo_tensor.nnz)
# #out1sp = F.spmm_coo(coo_tensor, w1.t())
# #print(w1.t().shape)
# #out1 = out1dn + out1sp
# # fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
# #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)
# # delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
# d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
# #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)
# # delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
# d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
# #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)
# # grad1
# #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
# #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
# #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)
# ## grad2
# #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
# torch.cuda.synchronize()
# t8 = time.time() - t0
# print(t8)
n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()
#dim1 = [2*1024]
#dim4 = [2*1024]
#dim1 = [4]
#dim4 = [4]
dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
inner = torch.randint(1, 128, size=(1,)).item()
bias = None
if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
formatB = F.get_special_format_str()
for i in range(1):
A = torch.randn(dim1, inner, device="cuda")
B = torch.randn(dim4, inner, device="cuda")
C1 = torch.matmul(A.half(), B.t().half())
if has_bias: C1 += bias
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, formatB)
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, "row", state=SC)
C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
if has_bias: C4 += bias
# TODO: is something wrong here? If so, the problem goes deeper
#n = C1.numel()
#p = 0.06
std = C1.std(0).view(1, -1)
C1 /= std
C4 /= std
#assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
#assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
#torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
n = C5.numel()
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))
n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
for i in range(k):
threshold = 3.0
A = torch.randn(dim1, dim2, device="cuda").half()
A_truncated = A.clone()
A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
if dims == 2:
row_stats1, _ = torch.abs(A.float()).max(1)
col_stats1, _ = torch.abs(A.float()).max(0)
row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
assert False
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=threshold
A_blocked = einops.rearrange(
"(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
block_size=64 * 4,
nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
nnz_block_ptr1 = torch.zeros(
nnz_rows1_counts.shape[0] + 1,
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
torch.testing.assert_close(col_stats1_trunc, col_stats2)
torch.testing.assert_close(row_stats1_trunc, row_stats2)
torch.testing.assert_close(, nnz_block_ptr2)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=0.0
torch.testing.assert_close(col_stats1, col_stats2)
torch.testing.assert_close(row_stats1, row_stats2)
assert nnz_block_ptr2 is None
n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# max difference is 1 due to rounding differences
torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
num_not_close_rows = (
(torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
num_not_close_cols = (
(torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
# allow for 1:500 error due to rounding differences
min_error = 1 / 500
if num_not_close_cols > (min_error * n):
f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
assert False
if num_not_close_rows > (min_error * n):
f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
assert False
torch.testing.assert_close(Srow.flatten().float(), statsA)
torch.testing.assert_close(Scol.flatten().float(), statsAt)
n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
out1 = torch.matmul(A.half(), B.t().half())
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
torch.testing.assert_close(maxA.flatten().float(), stats1a)
torch.testing.assert_close(maxB.flatten().float(), stats2a)
torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(C2a, "col_turing")
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, "row", state=SC)
out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1 * 1.025
n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
for i in range(k):
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
C1 = torch.matmul(A, B.t())
out1 = torch.matmul(A.half(), B.t().half())
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c
outC32, SC = F.igemmlt(
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max()
if maxval == 127:
scale = 1.5
scale = maxval / 120
out3 = C3 * maxA * absmaxB * c / (127 * 127)
C4 = torch.matmul(C1a.float(), CB.float().t())
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
C = torch.matmul(CA.float(), CB.t().float())
out4 = C * SA * SB / (127 * 127)
# out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
# print('='*80)
# print(out1)
# print(out2)
# print(out3)
# print(out1)
# print(out2)
# print(out3)
err1.append(torch.abs(out1 - out2).mean().item())
err2.append(torch.abs(out1 - out3).mean().item())
err3.append(torch.abs(out1 - out4).mean().item())
# assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
print(sum(err1) / len(err1))
print(sum(err2) / len(err2))
print(sum(err3) / len(err3))
dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
# warmpup
for i in range(k):
C1 = torch.matmul(A, B.t())
t0 = time.time()
for i in range(k):
C1 = torch.matmul(A, B.t())
print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0 * inner * scale
row_scale = maxA / c
t0 = time.time()
for i in range(k):
outC32, SC = F.igemmlt(
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
t0 = time.time()
for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB)
print("vector-wise", time.time() - t0)
n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]
dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
names = [
for vals in values
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
elif dims == 3:
A = torch.randint(
10, 99, size=(dim1, dim2, dim3), device="cuda"
A.view(-1)[-1] = -1
if transpose:
At = A.t().contiguous()
out1, S1 = F.nvidia_transform(At, to_order=orderOut)
out1, S1 = F.nvidia_transform(A, to_order=orderOut)
out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)
assert S1[0][0] == S2[0][0]
assert S1[0][1] == S2[0][1]
# print(out1)
# print(out2)
torch.testing.assert_close(out1, out2)
n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]
dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
for vals in values
def test_overflow():
formatB = F.get_special_format_str()
for i in range(2):
a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
Ca, Sa = F.nvidia_transform(a, "col32")
Cb, Sb = F.nvidia_transform(b, formatB)
c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
c2 = torch.matmul(a.float(), b.float().t())
n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
threshold = 3.00
for i in range(k):
A = torch.randn(dim1, dim2, device="cuda").half()
idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
A, threshold=threshold
if coo_tensor is not None:
A1 = A * idx
A2 = torch.zeros_like(A)
coo_tensor.rowidx.long(), coo_tensor.colidx.long()
] = coo_tensor.values
torch.testing.assert_close(A1, A2)
A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
# dim3 = 17
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
if transposed_B:
B = torch.randn(dim3, dim2).cuda().half()
B = torch.randn(dim2, dim3).cuda().half()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz,,, values
A2 = A * idx
if transposed_B:
out2 = F.spmm_coo(cooA, B.t())
out1 = torch.matmul(A2, B.t())
out2 = F.spmm_coo(cooA, B)
out1 = torch.matmul(A2, B)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
def test_spmm_bench():
batch = 2
model = 1024 * 1
hidden = model * 4
seq = 1024
dim1 = batch * seq
dim2 = model
dim3 = hidden
threshold = 4
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.randn(dim2, dim3, device="cuda").half()
for i in range(10):
C1 = bnb.matmul(A, B.t())
t0 = time.time()
for i in range(k):
C1 = bnb.matmul(A, B.t())
t8 = time.time() - t0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
print(nnz / idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz,,, values
for i in range(10):
out2 = F.spmm_coo(cooA, B)
t0 = time.time()
for i in range(k):
out2 = F.spmm_coo(cooA, B)
tsp = time.time() - t0
print(tsp, t8)
print(tsp / t8)
n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
threshold = 3.0
formatB = "col_turing"
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
w1 = torch.randn(dim1, dim2).cuda().half()
out1 = torch.matmul(A, w1.t())
Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
CTw1, Sw1 = F.transform(Cw1, formatB)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
A, threshold=threshold
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
assert coo_tensor is not None
out4 = F.spmm_coo(coo_tensor, w1.t())
out5 = out3 + out4
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
def test_matmuls():
a = torch.randn(256, 512).half().cuda()
b = torch.randn(256, 512).half().cuda()
c1 = torch.matmul(a, b.t())
c2 = bnb.matmul(a, b)
c3 = bnb.matmul_cublas(a, b.t())
err1 = torch.abs(c1 - c2).mean().item()
err2 = torch.abs(c1 - c3).mean().item()
assert err1 < 0.2
assert err2 < 0.2
print(err1, err2)
n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
"dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func)
threshold = 3.3
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
if dtype == torch.float16:
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
B, SB = F.vectorwise_quant(B, quant_type="linear")
# B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz,,, values
A2 = A * idx
out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
out1 += out.clone()
out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
# print(B)
# print(out1)
# print(out2)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p * n)
std = out1.std()
out1 /= std
out2 /= std
out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
# torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
# Bt = torch.randn(dim2*4, dim2, device='cuda').half()
# torch.cuda.synchronize()
# t0 = time.time()
# print(A2.shape, B.shape)
# for i in range(100):
# #out3 = F.spmm_coo(cooA, Bt.t())
# #out2 = F.spmm_coo(cooA, B)
# #out2 = F.spmm_coo_very_sparse(cooA, B)
# #out1 = torch.matmul(A, Bt.t())
# torch.cuda.synchronize()
# print(time.time() - t0)
def test_coo2csr():
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz,,, values
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
idx = A2 != 0
torch.testing.assert_close(A2[idx], csrA.values)
def test_coo2csc():
threshold = 1
A = torch.randn(128, 128).half().cuda()
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz,,, values
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
idx = A2.t() != 0
torch.testing.assert_close(A2.t()[idx], cscA.values)
n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
threshold = 6.0
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
Bt = B.t().contiguous()
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
rowidx = torch.randint(0, A.shape[-1], size=(15,))
A[:, rowidx] = 8.0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(
A.shape[0], A.shape[1], nnz,,, values
A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
max_count, max_idx = torch.sort(counts, descending=True)
torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
# torch.cuda.synchronize()
# print('fp16', time.time() - t0)
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo(cooA, B)
print("cusparse fp16", time.time() - t0)
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt)
print("int8", time.time() - t0)
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
print("int8+dequant", time.time() - t0)
t0 = time.time()
for i in range(100):
out2 = torch.matmul(A, B)
print("matmul", time.time() - t0)
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out = out1 + out2
print("sparse+ matmul", time.time() - t0)
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
print("partial matmul", time.time() - t0)
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
print("partial matmul", time.time() - t0)
batch_size = 1
seqdim = 1
values = []
#values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5120, 4*5120))
values.append((batch_size, seqdim, 6656, 4*6656))
#values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
iters = 1000
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
B_fp4, state = F.quantize_fp4(B)
B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
B_nf4, state_nf4 = F.quantize_nf4(B)
B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
# warmup
for i in range(iters):
torch.matmul(A, B.t())
t0 = time.time()
for i in range(iters):
torch.matmul(A, B.t())
print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#t0 = time.time()
#for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
#print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#t0 = time.time()
#for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
#print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#t0 = time.time()
#for i in range(iters):
# bnb.matmul(A, B)
#print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#t0 = time.time()
#for i in range(iters):
# bnb.matmul(A, B, threshold=6.0)
#print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
#C32A, SA = F.transform(CA, "col32")
#CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
#CxB, SB = F.transform(CB, to_order=formatB)
#t0 = time.time()
#for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
#print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1)
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#t0 = time.time()
#for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1)
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
#t0 = time.time()
#for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# out = Cout * statsB * statsA * (1.0 / (127 * 127))
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#t0 = time.time()
#for i in range(iters):
# linear8bit(A)
#print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#t0 = time.time()
#for i in range(iters):
# linearMixedBit(A)
#print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#t0 = time.time()
#for i in range(iters):
# linear8bit_train(A)
#print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#t0 = time.time()
#for i in range(iters):
# linear8bit_train(A)
#print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
def test_zeropoint():
def quant_zp(x):
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
if dyna == 0:
dyna = 1
qx = 254.0 / dyna
minx = x.min()
# zpx = torch.round(minx* qx)
# zpx = 127 - torch.round(x.max()* qx)
zpx = torch.round(x.min() * qx) - 127
x = (qx * x) + zpx
return x, qx, zpx
batch = 2
seq = 512
model = 1024
hidden = 4 * model
A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
B = torch.randn(model, hidden, device="cuda").half() * 0.1
C0 = torch.matmul(A, B)
# A, SA = F.vectorwise_quant(A, quant_type='linear')
# B, SB = F.vectorwise_quant(B, quant_type='linear')
A = A.float()
B = B.float()
C1 = torch.matmul(A, B)
C3 = bnb.matmul(A.half(), B.t().contiguous().half())
zp = 1
# C2 = torch.matmul(A-zp, B)
# C2 += B.sum(0).view(1, -1)*zp
C2 = torch.matmul(A, B - zp)
C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A)
print(ca.min(), ca.max())
print((ca - cza).min(), (ca - cza).max())
zp = 1
scale = 2.0
C5 = torch.matmul((A * scale) - zp, B)
C5 += B.sum(0) * zp
C5 /= scale
CA, qa, zpa = quant_zp(A)
C4 = torch.matmul(CA, B)
C4 -= B.sum(0) * zpa
C4 /= qa
zpb = 1
zpa = 1
qa = 2
qb = 2
C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C6 -= zpa * zpb * A.shape[1]
C6 /= qa * qb
CA, qa, zpa = quant_zp(A)
CB, qb, zpb = quant_zp(B)
C7 = torch.matmul(CA, CB)
C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C7 -= zpa * zpb * A.shape[1]
C7 /= qa * qb
# print(C0.flatten()[:10])
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
err4 = torch.abs(C1 - C5).mean().item()
err5 = torch.abs(C1 - C6).mean().item()
err6 = torch.abs(C1 - C7).mean().item()
print(err1, err2, err3, err4, err5, err6)
def test_extract_outliers():
for i in range(k):
shapeA = (4096, 4096 * 4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
# idx = torch.Tensor([0]).int().cuda()
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
outliers1 = A[:, idx.long()]
CA, SA = F.transform(A, "col_turing")
outliers2 = F.extract_outliers(CA, SA, idx)
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_close(outliers1, outliers2)
CA, SA = F.transform(A, "col_ampere")
outliers2 = F.extract_outliers(CA, SA, idx)
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_close(outliers1, outliers2)
def test_blockwise_cpu_large():
diffs = []
reldiffs = []
batch = 128
seq = 128
for hidden in [128]:#, 14336]:
for blocksize in [4096, 16384]:
for i in range(2):
A1 = torch.randn(batch, seq, hidden, device='cpu')
t0 = time.time()
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
print(time.time() - t0)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
def test_fp8_quant():
for e_bits in range(1, 7):
p_bits = 7-e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
#assert diff < 0.0075
abserr = []
relerr = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
#assert diff < 0.0075
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
#assert diff < 0.0075
#print(3, sum(abserr)/len(abserr))
#print(3, sum(relerr)/len(relerr))
def test_few_bit_quant():
for bits in range(2, 9):
#print('='*30, bits, '='*30)
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
abserrs = []
relerrs = []
code = None
if method == 'linear':
code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == 'fp8':
ebits = math.ceil(bits/2)
pbits = bits-ebits-1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == 'dynamic':
code = F.create_dynamic_map(True, bits-0, bits).cuda()
elif method == 'quantile':
values = torch.randn(2048, 2048, device='cuda')
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
#print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
values = torch.randn(1, 32, device='cuda')
values /= values.abs().max()
#values[values.abs() < 1e-6] += 1e-5
q1 = []
v1 = []
for v in values[0]:
idx = torch.abs(v-code).argmin()
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(,
err2 = torch.abs(v2-values)
if idx.sum():
# some weird cases
err1 = torch.abs(v1-values).mean()
#assert err2.mean() <= err1
torch.testing.assert_close(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False
def test_kbit_quantile_estimation():
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
for bits in range(2, 9):
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1-val2).mean()
assert err < 0.038
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
for bits in range(2, 4):
total_values = 2**bits-1
p = np.linspace(0, 1, 2*total_values+1)
idx = np.arange(1, 2*total_values+1, 2)
p = p[idx]
offset = 1/(2*total_values)
p = np.linspace(offset, 1-offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
err = torch.abs(val1-val2).mean()
assert err < 0.035
def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half()
code =F.create_fp8_map(True, 3, 0, 4).cuda()
qa, SA = F.quantize_blockwise(a, code=code)
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
t0 = time.time()
for i in range(100):
qa, SA = F.quantize_blockwise(a)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_fp4_quant(dtype):
vals = list(product([0, 1], repeat=4))
code = {}
for bits in vals:
result = 0
bias = 3
sign, e1, e2, p1 = bits
idx = sign*8 + e1*4 + e2*2 + p1*1
sign = -1.0 if sign else 1.0
exp = e1*2 + e2*1
if exp == 0:
# sub-normal
if p1 == 0: result = 0
else: result = sign*0.0625
# normal
exp = 2**(-exp + bias + 1)
frac = 1.5 if p1 else 1.0
result = sign*exp*frac
code[idx] = result
A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
qa, SA = F.quantize_fp4(A1, blocksize=64)
A2 = F.dequantize_fp4(qa, SA)
err = (A1 - A2).abs().float()
relerr = (err/(A1.abs().float()+1e-8)).mean()
idx = err > 1.0
err = err.mean()
assert A2.dtype == dtype
assert err.item() < 0.1
assert relerr.item() < 0.28
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
for blocksize in [128, 64]:
errs1 = []
errs2 = []
for i in range(10):
A1 = torch.randn(1024, 1024, device='cuda').half()
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
err = (A1 - A2).abs().float()
relerr = (err/(A1.abs().float()+1e-15)).mean()
err = err.mean()
assert err.item() < 0.11
assert relerr.item() < 0.28
err = (A1 - A3).abs().float()
relerr = (err/(A1.abs().float()+1e-15)).mean()
err = err.mean()
assert err.item() < 0.11
assert relerr.item() < 0.28
#print(sum(errs1)/len(errs1), blocksize, quant_type)
#print(sum(errs2)/len(errs2), blocksize, quant_type)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
def test_bench_4bit_dequant(quant_type):
blocksize = 256
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
input_size = a.numel()/2
output_size = a.numel()*2
num_bytes = input_size+output_size
GB = num_bytes/1e9
max_theoretical_s = GB/768
b = torch.randn(128, 1024*12, device='cuda').half()
iters = 100
t0 = time.time()
for i in range(iters):
F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
#t0 = time.time()
#for i in range(iters):
# torch.matmul(b, a.t())
def test_normal_map_tree():
code = F.create_normal_map()
values =code[:8].tolist() + code[-8:].tolist()
num_pivots = 1
while num_pivots <16:
idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
num_pivots *= 2
pivots = []
for i in idx:
@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant, kind):
for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
#for dim in [4*1024]:
#for dim in [1*128]:
errs1 = []
errs2 = []
errs3 = []
relerrs1 = []
relerrs2 = []
relerrs3 = []
max_errs1 = []
max_errs2 = []
max_errs3 = []
for i in range(100):
if kind == 'fc1':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'fc2':
A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'attn':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'attn_packed':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
C3 = torch.matmul(A, B.t())
C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)
err1 = (C1-C2).abs().float()
err2 = (C3-C2).abs().float()
err3 = (C3-C1).abs().float()
mag1 = torch.abs(C1).float()+1e-5
mag2 = torch.abs(C3).float()+1e-5
mag3 = torch.abs(C3).float()+1e-5
relerr1 = err1/mag1
relerr2 = err2/mag2
relerr3 = err3/mag3
max_err1 = err1.max()
max_err2 = err2.max()
max_err3 = err3.max()
c = int(C1.numel()*0.0014*(dim/256))+1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
absratio = err2/err3
relratio = relerr2/relerr3
maxratio = relerr2/relerr3
# for debugging if the tests fails
#print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
#print(f'inference vs training abs: {err1}')
#print(f'inference vs training rel: {relerr1}')
#print(f'inference vs training max: {maxerr1}')
#print(f'inference vs training vs torch err ratio abs: {absratio}')
#print(f'inference vs training vs torch err ratio rel: {relratio}')
#print(f'inference vs training vs torch err ratio max: {maxratio}')
if dtype == torch.float16:
if dim <= 512:
assert err1 < 7e-5
assert relerr1 < 0.0008
assert err1 < 6e-5
assert relerr1 < 2e-4
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.float32:
if dim <= 512:
assert err1 < 5e-8
assert relerr1 < 1e-6
assert maxerr1 < 1e-7
assert err1 < 5e-8
assert relerr1 < 8e-6
assert maxerr1 < 1e-7
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.bfloat16:
if dim <= 512:
assert err1 < 5e-4
assert relerr1 < 0.007
assert maxerr1 < 0.015
assert err1 < 2e-4
assert relerr1 < 0.002
assert maxerr1 < 0.0012
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
n = 32*10
A = F.get_paged(n, n, dtype=torch.float32)
B = F.get_paged(n, n, dtype=torch.uint8)
B2 = F.get_paged(n, n, dtype=torch.float32)
assert A.is_paged
assert B.is_paged
assert A.page_deviceid==0
assert B.page_deviceid==0
F.fill(A, 17.0)
F.fill(B, 17)
F.fill(B2, 2)
assert (A==17).sum().item() == n*n
assert (B==17).sum().item() == n*n
C = A*B.float()
assert (C==289).sum().item() == n*n
F._mul(A, B2)
F._mul(A, B2)
F._mul(A, B2)
assert (A==17*(2**3)).sum().item() == n*n
# F.prefetch_tensor(A)
# F.prefetch_tensor(B)
# F.fill(B2, 17.0)
# F._mul(A, B2)
# F.prefetch_tensor(A, to_cpu=True)
# F.prefetch_tensor(B, to_cpu=True)
# F.prefetch_tensor(B2, to_cpu=True)
# torch.cuda.synchronize()
# assert (A==17).sum().item() == n*n
# torch.testing.assert_close(A, torch.ones(A.shape)*289)