2022-07-22 21:41:05 +00:00
|
|
|
import math
|
|
|
|
import random
|
|
|
|
import time
|
2021-10-06 02:16:20 +00:00
|
|
|
from itertools import product
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
import einops
|
|
|
|
import pytest
|
|
|
|
import torch
|
2022-11-06 21:05:25 +00:00
|
|
|
import numpy as np
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
import bitsandbytes as bnb
|
2021-10-06 02:16:20 +00:00
|
|
|
from bitsandbytes import functional as F
|
2022-11-06 21:05:25 +00:00
|
|
|
from scipy.stats import norm
|
2021-10-06 02:16:20 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
# Keep printed tensors compact and readable when a comparison fails.
torch.set_printoptions(
    precision=5,
    sci_mode=False,
    linewidth=120,
    edgeitems=20,
    threshold=10000,
)

# Number of repetitions used by the randomized loops in this module.
k = 20
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2023-05-02 15:58:59 +00:00
|
|
|
def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
    """Check that at most ``count`` elements of ``a`` and ``b`` are not close.

    Closeness is judged element-wise by ``torch.isclose`` with the given
    tolerances.

    Args:
        a, b: tensors to compare.
        rtol, atol: relative and absolute tolerances.
        count: number of mismatching elements that is still tolerated.
        throw: when True and more than ``count`` elements mismatch, print a
            summary and raise through ``torch.testing.assert_close``.

    Returns:
        The number of elements that are not close.

    Raises:
        AssertionError: from ``torch.testing.assert_close`` when ``throw`` is
            True and the mismatch budget is exceeded.
    """
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (idx == 0).sum().item()
    if sumval > count and throw:
        # The check that failed is ``sumval <= count``.
        print(f"Too many values not close: assert {sumval} <= {count}")
        # Let torch raise with its detailed mismatch report.
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
    return sumval
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
class FFN(torch.nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Linear.

    The output feature size equals the input feature size.  Both weight
    matrices are re-initialized with Xavier-uniform after construction.
    """

    def __init__(self, input_features, hidden_size, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            for layer in (self.fc1, self.fc2):
                torch.nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-10-27 11:14:13 +00:00
|
|
|
class Timer:
    """Accumulating timer built on CUDA events.

    ``tick(name)`` starts a named measurement.  Calling ``tick`` again with the
    same name, or calling ``tock``, stops it and adds the elapsed milliseconds
    to a running total in ``self.agg``.
    """

    def __init__(self):
        self.starts = {}  # name -> start torch.cuda.Event
        self.ends = {}  # name -> end torch.cuda.Event
        self.agg = {}  # name -> accumulated milliseconds

    def tick(self, name="default"):
        """Start timing ``name``, or stop it if a measurement is running."""
        if name in self.starts:
            self.tock(name, evict=True, print_ms=False)
            return
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        self.starts[name] = start
        self.ends[name] = end
        start.record()

    def tock(self, name="default", evict=True, print_ms=True):
        """Stop timing ``name`` if it is running and return its total in ms.

        Synchronizes the device so the elapsed time is final.  With ``evict``
        the events are discarded, so a later ``tick`` starts fresh.  Raises
        ``KeyError`` if ``name`` has never been measured.
        """
        if name in self.ends:
            end = self.ends[name]
            end.record()
            torch.cuda.synchronize()
            elapsed = self.starts[name].elapsed_time(end)
            self.agg[name] = self.agg.get(name, 0.0) + elapsed
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")

        return self.agg[name]

    def reset(self):
        """Discard all pending events and accumulated totals."""
        self.starts, self.ends, self.agg = {}, {}, {}
        print("Resetting benchmark data")
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def setup():
    """Module-level pytest setup hook; nothing needs preparing."""
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def teardown():
    """Module-level pytest teardown hook; nothing needs cleaning up."""
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-08-01 16:32:47 +00:00
|
|
|
@pytest.mark.parametrize(
    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype):
    """estimate_quantiles should recover the quantiles of the input data."""
    # Uniform [0, 1) input: quantiles equal their probabilities, which here are
    # the 256 points (2k + 1) / 512.
    uniform = torch.rand(1024, 1024, device="cuda").to(dtype)
    code = F.estimate_quantiles(uniform)
    percs = torch.linspace(1 / 512, 511 / 512, 256, device=uniform.device)
    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

    # Normal input: compare against torch's exact quantiles at the same points.
    normal = torch.randn(1024, 1024, device="cuda").to(dtype)
    code = F.estimate_quantiles(normal)
    quantiles = torch.quantile(normal.float(), percs)
    n_far = (torch.abs(code - quantiles) > 5e-02).sum().item()
    assert n_far == 0
|
|
|
|
|
|
|
|
|
|
|
|
def test_quantile_quantization():
    """Round-trips through quantile (no-absmax) quantization stay accurate."""
    for _ in range(100):
        normal = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(normal)
        restored = F.dequantize_no_absmax(F.quantize_no_absmax(normal, code), code)
        mean_err = torch.abs(normal - restored).mean().item()
        assert mean_err < 0.0075

        uniform = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(uniform)
        restored = F.dequantize_no_absmax(F.quantize_no_absmax(uniform, code), code)
        mean_err = torch.abs(uniform - restored).mean().item()
        torch.testing.assert_close(uniform, restored, atol=5e-3, rtol=0)
        assert mean_err < 0.001
|
|
|
|
|
|
|
|
|
|
|
|
def test_dynamic_quantization():
    """Dynamic (absmax) quantization round-trip error stays small."""
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        # Add the epsilon after abs() so the denominator is strictly positive;
        # abs(A1 + eps) is exactly zero whenever A1 == -eps.
        reldiff = diff / (torch.abs(A1) + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # diffs / reldiffs are kept for debugging (e.g. print their means).

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004
|
|
|
|
|
|
|
|
|
2023-04-19 18:48:47 +00:00
|
|
|
|
2023-07-05 02:58:31 +00:00
|
|
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
    """Blockwise quantization round-trips within error bounds and keeps dtype.

    Checked separately for normal (randn) and uniform (rand) inputs.
    """
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        # Epsilon goes outside abs() so the denominator can never be zero.
        reldiff = diff / (torch.abs(A1.float()) + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs) / len(diffs)
    relerr = sum(reldiffs) / len(reldiffs)
    assert abserr < 0.011
    assert relerr < 0.018
    assert A2.dtype == dtype

    # Reset BOTH accumulators. Previously only ``diffs`` was cleared, so the
    # uniform-input relative error below was averaged with the randn results.
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / (torch.abs(A1.float()) + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs) / len(diffs)
    relerr = sum(reldiffs) / len(reldiffs)
    assert abserr < 0.0035
    assert relerr < 0.015
    assert A2.dtype == dtype
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
|
2022-08-01 16:32:47 +00:00
|
|
|
@pytest.mark.parametrize(
    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype):
    """percentile_clipping matches a reference norm history and clip value."""
    # Reference history of plain gradient norms, maintained by this test.
    gnorm_vec1 = torch.zeros(100, device="cuda")
    # History updated by F.percentile_clipping; it is compared through sqrt()
    # below, so it holds squared norms.
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        # The expected value must be parenthesized: written inline, the
        # statement parsed as ``assert (gnorm_scale == 1.0) if ... else
        # clip2 / gnorm1``, which only checked that the ratio was truthy.
        expected_scale = 1.0 if gnorm1 < clip2 else clip2 / gnorm1
        assert gnorm_scale == expected_scale

        gnorm2 = torch.norm(g.float())
        if step == 1:
            # The first step seeds the whole history with the current norm.
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_close(clip1, clip2)
        torch.testing.assert_close(gnorm1, gnorm2)
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def quant(x):
    """Symmetric per-tensor int8 quantization with absmax scaling.

    Returns ``(max1, q)``.  ``max1`` is ``abs(x).max()`` and
    ``q = round(x / max1 * 127)`` cast to int8.  An all-zero input uses a
    scale of 1.0 instead of dividing by zero, matching ``quant_multi``.
    """
    max1 = torch.abs(x).max()
    # Guard 0/0 -> NaN, whose cast to int8 is undefined.
    max1 = torch.where(max1 == 0, torch.ones_like(max1), max1)
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def dequant(c, maxC):
    """Invert ``quant``: map int8 codes back to floats using absmax ``maxC``."""
    scale = maxC / 127
    return c.float() * scale
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
def mm_dequant(maxA, maxB, C):
    """Rescale an integer matmul result ``C`` by both operands' absmax scales."""
    scale_a = maxA / 127
    scale_b = maxB / 127
    return C.float() * scale_a * scale_b
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
def quant_multi(x, dim):
    """Vector-wise int8 quantization with one absmax scale per slice along ``dim``.

    Returns ``(max1, q)``.  ``max1`` keeps the reduced dimension, so it
    broadcasts against ``x``, and zero maxima are replaced by 1.0.
    """
    scale = torch.abs(x).amax(dim=dim, keepdim=True)
    scale[scale == 0] = 1.0
    q = torch.round(x / scale * 127)
    return scale, q.to(torch.int8)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def quant_multi_chunk(x, dim, chunk_size=32):
    """Vector-wise int8 quantization of a 2-D tensor via chunked reshapes.

    For ``dim == 1`` each element is scaled by the absmax of its row.  For
    ``dim == 0`` each element is scaled by the absmax of its column.  The
    reductions span whole rows/columns, so the chunking only affects the
    intermediate layout.  Unlike ``quant_multi``, the returned ``max1`` is
    expanded to the full shape of ``x``.

    The einops rearrange requires ``x.shape[0]`` (``dim == 1``) or
    ``x.shape[1]`` (``dim == 0``) to be divisible by ``chunk_size``.

    Raises:
        ValueError: if ``dim`` is not 0 or 1.
    """
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    else:
        # Previously fell through and failed with UnboundLocalError on max1.
        raise ValueError(f"dim must be 0 or 1, got {dim}")
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def quant_minmax(A):
    """Stub for min/max quantization.

    NOTE(review): this evaluates ``A.min()`` and ``A.max()`` but discards
    both and returns None; nothing visible in this file calls it.
    """
    A.min()
    A.max()
    return None
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def mean(xx):
    """Arithmetic mean of a non-empty sequence of numbers, as a float."""
    total = sum(xx)
    return total / float(len(xx))
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
# Fixed problem sizes for test_approx_igemm.
dim1 = [1024 * 2]
dim2 = [1024 * 16]

# Each entry is a 5-tuple of quantize/dequantize callables.  test_approx_igemm
# uses index 0 (quantize A), 1 (quantize B), 2 (dequantize A) and 4
# (dequantize the integer matmul result).  The per-tensor "linear" variant
# ignores the dim argument.
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    ),
    (quant_multi, quant_multi, dequant, dequant, mm_dequant),
]
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    f"dim1_{d1}_dim2_{d2}_quant_{method}_batched_{is_batched}"
    for d1, d2, method, is_batched in values_names
]
|
|
|
|
|
|
|
|
|
2022-08-01 16:32:47 +00:00
|
|
|
@pytest.mark.parametrize(
    "dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    """Compare int8-quantized matmuls against the fp32 reference.

    Absolute and relative errors, normalized by the reference's std, are
    collected but not asserted.  The only hard check is the dequantization
    round-trip of A.
    """
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_close(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        # Add an epsilon (as test_minmax_igemm does) so exact zeros in the
        # reference do not turn the relative error into inf.
        relerr = err / (torch.abs(out2) + 1e-7)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def test_stable_embedding():
    """StableEmbedding can be constructed and its parameters re-initialized."""
    embedding = bnb.nn.StableEmbedding(1024, 1024)
    embedding.reset_parameters()
|
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
# Randomized shapes for test_igemm.  The test rounds them down to multiples
# of 32 (hidden) and 16 (batch, seq).
n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
# Test ids: "_transpose_" was previously joined with a stray comma.
names = [
    "hidden_dim_{}_batch_dim_{}_transpose_{}_seq_dim_{}".format(*vals)
    for vals in values
]
|
|
|
|
|
|
|
|
|
2022-08-01 16:32:47 +00:00
|
|
|
@pytest.mark.parametrize(
    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    """F.igemm on int8 inputs matches a float matmul exactly.

    ``transpose`` is ``(transpose_A, transpose_B)``.  The 2-D loop covers all
    four combinations.  The 3-D loop covers only an untransposed A.
    """
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)

    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim)
            if not transpose[0]
            else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        # Operands are stored transposed when requested; undo that here.
        A_op = A.t() if transpose[0] else A
        B_op = B.t() if transpose[1] else B
        out2 = torch.matmul(A_op.float(), B_op.float())
        out = F.igemm(A_op, B_op)
        torch.testing.assert_close(out.float(), out2)

    if transpose[0]:
        # A 3-D A is never transposed.  Previously no branch matched in this
        # case, so the assertion re-compared stale ``out``/``out2`` from the
        # loop above instead of testing anything.
        return

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        B_op = B.t() if transpose[1] else B
        out2 = torch.matmul(A.float(), B_op.float())
        out = F.igemm(A, B_op)
        torch.testing.assert_close(out.float(), out2)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Randomized shapes for test_dim3_igemm.
n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [f"seq_dim{s}_hidden_dim{h}_batch_dim{b}" for s, h, b in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    """igemm with an ``out`` buffer contracts the batch and sequence dims of 3-D inputs."""
    seq_dim -= seq_dim % 32
    hidden_dim -= hidden_dim % 32
    batch_dim -= batch_dim % 2
    for _ in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        # Reference: sum over batch and sequence -> (hidden_dim, 1024).
        reference = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        result = F.igemm(A, B, out=iout)
        torch.testing.assert_close(result.float(), reference)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
# Randomized shapes for test_minmax_igemm.
n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    f"seq_dim={s}_hidden_dim={h}_batch_dim={b}_transpose{t}"
    for s, h, b, t in values
]
|
|
|
|
|
|
|
|
|
2022-08-01 16:32:47 +00:00
|
|
|
@pytest.mark.parametrize(
    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    """Min-max (offset) int8 quantization of A yields a small matmul error.

    A is quantized per last-dim slice with an offset, and B vector-wise with
    ``quant_multi``.  The offset contribution is added back after the integer
    matmul.  A symmetric vector-wise variant (``out3``) is also computed, but
    its errors are not asserted.
    """

    def min_max(x):
        # Map each last-dim slice from [min, max] onto [-127, 127], then cast.
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs, relerrs, errs2, relerrs2 = [], [], [], []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        B_shape = (256, hidden_dim) if transpose else (hidden_dim, 256)
        B = torch.normal(0, 0.5, size=B_shape, device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=1)
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        # Normalize everything by the reference's spread before comparing.
        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)
        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
# Randomized shapes for test_ibmm; transpose is (transpose_A, transpose_B).
n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    f"dim1_{a}_dim2_{b}_dim3_{c}_dim4_{d}_transpose_{t}"
    for a, b, c, d, t in values
]
|
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    """Batched igemm on int8 tensors matches torch.bmm in float.

    When ``transpose = (tA, tB)`` flags an operand, it is generated with its
    last two dims swapped and permuted back before the multiply.
    """
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)

    def swap_last(t):
        return t.permute([0, 2, 1])

    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        lhs = swap_last(A) if transpose[0] else A
        rhs = swap_last(B) if transpose[1] else B
        out2 = torch.bmm(lhs.float(), rhs.float())
        out = F.igemm(lhs, rhs)
        torch.testing.assert_close(out.float(), out2.float())
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
# Randomized shapes for test_vector_quant (dim1 is not used by the test).
n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = [f"dim1_{a}_dim2_{b}_dim3_{c}" for a, b, c in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    """vectorwise_quant/dequant round-trips with at most 0.2% outliers."""
    dim2 -= dim2 % 16
    dim3 -= dim3 % 16
    for _ in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        tolerated = int(A1.numel() * 0.002)
        assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=tolerated)
|
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Parameter grid for test_nvidia_transform.
n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = [
    f"dim1_{d1}_dim2_{d2}_dim3_{d3}_dims_{nd}_dtype_{dt}_orderA_{oa}"
    f"_orderOut_{oo}_transpose_{tr}"
    for d1, d2, d3, nd, dt, oa, oo, tr in values
]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-10-24 18:54:25 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """nvidia_transform produces the expected layout and round-trips to row.

    Only "col32" output is exercised for 3-D inputs and for int32 data.
    """
    # Compare against the parametrized ``orderOut``.  This used to read the
    # module-level list ``out_order``, which never equals "col32", so every
    # 3-D and every int32 case returned without testing anything.
    if dims == 3 and orderOut != "col32":
        return
    if dtype == torch.int32 and orderOut != "col32":
        return

    def round_up(value, multiple):
        # Smallest multiple of ``multiple`` that is >= ``value``.
        return (value + multiple - 1) // multiple * multiple

    # NOTE(review): the result is unused.  The lookup is kept because it
    # presumably rejects unsupported dtype/order combinations.
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
            dtype
        )

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_close(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_close(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        # All leading dims are flattened into rows and the last dim is padded
        # up to a multiple of 32.  The old ``c + (32 - c % 32)`` formula added
        # an extra 32 columns when c was already aligned.
        rows = A.shape[0] if dims == 2 else A.shape[0] * A.shape[1]
        n = rows * round_up(A.shape[-1], 32)
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = round_up(A.shape[0], 8) * round_up(A.shape[1], 32)
        assert out.numel() == n
        # Sanity check of the row-major flat index.
        # TODO: also compare against ``out`` once the tile offset is settled.
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                assert A.flatten()[row * A.shape[1] + col] == A[row, col]

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_close(A, out2)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Randomized shapes for test_igemmlt_int.
n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
dims = (2, 3)
ldb = [0]
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    f"dim1_{a}_dim2_{b}_dim3_{c}_dim4_{d}_dims_{nd}_ldb_{ld}"
    for a, b, c, d, nd, ld in values
]
|
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    """Check F.igemmlt on int8 inputs against a float32 torch.matmul reference.

    Each of the ``k`` iterations multiplies a random int8 A (2D or 3D, chosen
    by ``dims``) by a random int8 B twice: once as ``A @ B.t()`` and once with
    a (dim3, dim4) B that is transposed during its col_turing transform.
    ``ldb`` is part of the parametrization but is not used.
    """

    def rand_int8(shape):
        # torch.randint's upper bound is exclusive: values lie in [-128, 126]
        return torch.randint(-128, 127, size=shape, device="cuda").to(torch.int8)

    def igemm_to_row(a_col32, state_a, b_turing, state_b):
        acc, state_acc = F.igemmlt(a_col32, b_turing, state_a, state_b)
        acc_row, _ = F.nvidia_transform(acc, "row", state=state_acc)
        return acc_row.float()

    for _ in range(k):
        if dims == 2:
            a_shape = (dim1, dim3)
        elif dims == 3:
            a_shape = (dim1, dim2, dim3)
        A = rand_int8(a_shape)
        B = rand_int8((dim4, dim3))
        expected = torch.matmul(A.float(), B.t().float())

        A_col32, state_A = F.transform(A, "col32")
        B_turing, state_B = F.transform(B, "col_turing")
        torch.testing.assert_close(
            expected, igemm_to_row(A_col32, state_A, B_turing, state_B)
        )

        # transpose: B is (dim3, dim4) and gets transposed by the transform
        B = rand_int8((dim3, dim4))
        expected = torch.matmul(A.float(), B.float())
        Bt_turing, state_Bt = F.transform(B, "col_turing", transpose=True)
        torch.testing.assert_close(
            expected, igemm_to_row(A_col32, state_A, Bt_turing, state_Bt)
        )
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
# Fixed 32x32 shapes for test_igemmlt_half; only 2D inputs are exercised.
dim1, dim2, dim3, dim4 = [32], [32], [32], [32]
dims = (2,)
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    f"dim1_{a}_dim2_{b}_dim3_{c}_dim4_{d}_dims_{e}" for a, b, c, d, e in values
]
|
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    """Check the fp16 -> int8 double_quant / igemmlt / mm_dequant pipeline.

    A (2D, or 3D flattened to 2D) and B are row-wise quantized with
    ``F.double_quant``, multiplied with ``F.igemmlt`` and dequantized with
    ``F.mm_dequant``. The result is compared with an fp16 ``torch.matmul``.
    """
    formatB = F.get_special_format_str()
    for _ in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        # smoke-test the high-level API on the same inputs (result not compared)
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        # CA / CB are the row-wise quantized tensors, so they must be dequantized
        # with the matching row-wise stats statsA / statsB (as in
        # test_integrated_igemmlt); statsAt / statsBt belong to CAt / CBt.
        output = F.mm_dequant(out1_32, Sout1_32, statsA, statsB)

        torch.testing.assert_close(
            C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05
        )
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# (batch, seq, model, hidden) sizes benchmarked by test_bench_8bit_training.
batch_size = 2
seqdim = 512
values = [
    (batch_size, seqdim, model_dim, hidden_dim)
    for model_dim, hidden_dim in (
        (4 * 1024, 3 * 4 * 1024),
        (5120, 3 * 5120),
        (12 * 1024, 4 * 12 * 1024),
    )
]
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    """Time ``k`` fp16 fc1 matmuls (``A @ w1.t()``) and print the elapsed seconds.

    This is a benchmark, not a correctness test: nothing is asserted. A few
    untimed warmup matmuls run first so that one-time startup costs are not
    part of the measurement. Only the fc1 forward matmul is timed; an int8
    variant of the full forward/backward step is not implemented here.
    """
    num_warmup = 3
    A = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    print("")

    # flatten (batch, seq, model) -> (batch * seq, model)
    A = A.view(-1, A.shape[-1]).contiguous()

    for _ in range(num_warmup):
        torch.matmul(A, w1.t())
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(k):
        torch.matmul(A, w1.t())  # fc1
    # CUDA kernels run asynchronously; wait for them before reading the clock
    torch.cuda.synchronize()
    t16 = time.perf_counter() - t0
    print(t16)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Random (dim1, dim4) sizes for test_dequant_mm, crossed with weight layouts and
# with/without bias. NOTE: this rebinds the module-level name ``formatB`` to a list.
n = 2
dim1, dim4 = (
    torch.randint(64, high, size=(n,)).tolist() for high in (256, 1024)
)
dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = [
    f"dim1_{d1}_dim4_{d4}_dims_{nd}_formatB_{fmt}_has_bias_{wb}"
    for d1, d4, nd, fmt, wb in values
]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-08-16 17:56:17 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
    """Compare vectorwise-quantized int8 igemmlt + dequant with an fp16 matmul.

    Both results are normalized by the per-column std of the reference, and at
    most 1% of elements may fall outside ``atol=0.015, rtol=0.1``.

    NOTE(review): the ``formatB`` parameter is immediately replaced by
    ``F.get_special_format_str()``, so both parametrized layouts run the same
    path. ``fused`` (``F.mm_dequant`` with bias) is computed, but only its
    element count is used.
    """
    inner = torch.randint(1, 128, size=(1,)).item()
    bias = torch.randn(dim4, device='cuda', dtype=torch.float16) if has_bias else None
    formatB = F.get_special_format_str()

    A = torch.randn(dim1, inner, device="cuda")
    B = torch.randn(dim4, inner, device="cuda")
    reference = torch.matmul(A.half(), B.t().half())
    if has_bias:
        reference += bias

    A_q, absmax_A = F.vectorwise_quant(A, dim=1)
    B_q, absmax_B = F.vectorwise_quant(B, dim=1)

    A_col32, state_A = F.nvidia_transform(A_q, "col32")
    B_fmt, state_B = F.nvidia_transform(B_q, formatB)
    acc, state_acc = F.igemmlt(A_col32, B_fmt, state_A, state_B)

    acc_row, _ = F.nvidia_transform(acc, "row", state=state_acc)
    dequant = F.vectorwise_mm_dequant(acc_row.float(), absmax_A, absmax_B.t())
    if has_bias:
        dequant += bias

    # TODO: is something wrong here? If so, the problem goes deeper
    col_std = reference.std(0).view(1, -1)
    reference /= col_std
    dequant /= col_std

    fused = F.mm_dequant(acc, state_acc, absmax_A.flatten(), absmax_B.flatten(), bias=bias)
    num_elements = fused.numel()
    assert_all_approx_close(
        reference, dequant, atol=0.015, rtol=0.1, count=int(0.01 * num_elements)
    )
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Fixed 1024 x 1024 input for test_colrow_absmax. Its einops pattern tiles rows
# by 16 and columns by 256, so both sizes must be multiples of those.
n = 2
dim1, dim2 = [1 * 1024], [1 * 1024]
dims = (2,)
values = list(product(dim1, dim2, dims))
names = [f"dim1_{rows}_dim2_{cols}_dims_{nd}" for rows, cols, nd in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    """Check F.get_colrow_absmax against torch reductions.

    With ``threshold=3.0``:
    - elements with ``|x| >= threshold`` must be excluded from the row and
      column absmax;
    - ``nnz_block_ptr`` must equal the cumulative sum of outlier counts over the
      tiling defined by the einops pattern below.

    With ``threshold=0.0`` the plain row/column absmax is expected and no block
    pointer is returned.
    """
    if dims != 2:
        pytest.fail(f"test_colrow_absmax only supports dims == 2, got {dims}")
    threshold = 3.0
    for _ in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        # zero the outliers with the same threshold passed to get_colrow_absmax
        A_truncated[torch.abs(A_truncated) >= threshold] = 0.0

        A_abs = torch.abs(A.float())
        A_trunc_abs = torch.abs(A_truncated.float())
        row_stats1, _ = A_abs.max(1)
        col_stats1, _ = A_abs.max(0)
        row_stats1_trunc, _ = A_trunc_abs.max(1)
        col_stats1_trunc, _ = A_trunc_abs.max(0)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        # A_blocked is already non-negative, so no second abs is needed
        nnz_rows1_counts = (A_blocked >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_close(col_stats1_trunc, col_stats2)
        torch.testing.assert_close(row_stats1_trunc, row_stats2)
        torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=0.0
        )

        torch.testing.assert_close(col_stats1, col_stats2)
        torch.testing.assert_close(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None
|
|
|
|
|
|
|
|
|
|
|
|
# Two random sizes per axis for test_double_quant (crossed: 4 cases).
n = 2
dim1, dim2 = (torch.randint(1, 4 * 1024, size=(n,)).tolist() for _ in range(2))
values = list(product(dim1, dim2))
names = [f"dim1_{rows}_dim2_{cols}" for rows, cols in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    """F.double_quant must agree with row- and column-wise F.vectorwise_quant.

    Quantized values may differ by at most 1 (rounding). At most 1 in 500
    elements may be reported as not close, and the absmax statistics must match.
    """
    max_bad_fraction = 1 / 500  # tolerated share of rounding mismatches

    def check_mismatch_count(num_bad, total):
        if num_bad > (max_bad_fraction * total):
            print(
                f"Min error exceeded {num_bad} elements are different. Error: {num_bad/total:.4f}"
            )
            assert False

    for _ in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        col_quant, col_absmax = F.vectorwise_quant(A, dim=0)
        row_quant, row_absmax = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_close(CA, row_quant, atol=1, rtol=0)
        torch.testing.assert_close(CAt, col_quant, atol=1, rtol=0)

        total = CAt.numel()
        bad_rows = (torch.isclose(CA, row_quant, atol=1) == 0).sum().item()
        bad_cols = (torch.isclose(CAt, col_quant, atol=1) == 0).sum().item()
        check_mismatch_count(bad_cols, total)
        check_mismatch_count(bad_rows, total)

        torch.testing.assert_close(row_absmax.flatten().float(), statsA)
        torch.testing.assert_close(col_absmax.flatten().float(), statsAt)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Four random (dim1, dim4, inner) triples for test_integrated_igemmlt (zipped, not crossed).
n = 4
dim1, dim4, inner = (
    torch.randint(1, 4 * 1024, size=(n,)).tolist() for _ in range(3)
)
values = list(zip(dim1, dim4, inner))
names = [f"dim1_{d1}_dim4_{d4}_inner_{kd}" for d1, d4, kd in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    """Compare the double_quant int8 matmul path with the vectorwise_quant path.

    Both quantizers must agree on stats and (up to rounding) on values. The
    mean error of the vectorwise path against an fp16 matmul must not exceed
    the double_quant path's error by more than 2.5%.
    """
    for _ in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        # A and B are already fp16, so the reference needs no extra casts
        reference = torch.matmul(A, B.t())

        CA_row, _CA_col, statsA_row, _statsA_col, _ = F.double_quant(A)
        CB_row, _CB_col, statsB_row, _statsB_col, _ = F.double_quant(B)
        A_vq, maxA = F.vectorwise_quant(A, dim=1)
        B_vq, maxB = F.vectorwise_quant(B, dim=1)

        torch.testing.assert_close(maxA.flatten().float(), statsA_row)
        torch.testing.assert_close(maxB.flatten().float(), statsB_row)
        torch.testing.assert_close(CA_row, A_vq, rtol=0, atol=1)
        torch.testing.assert_close(CB_row, B_vq, rtol=0, atol=1)

        # path 1: double_quant tensors, dequantized by F.mm_dequant
        A_col32, SA = F.nvidia_transform(CA_row, "col32")
        B_turing, SB = F.nvidia_transform(CB_row, "col_turing")
        acc, SC = F.igemmlt(A_col32, B_turing, SA, SB)
        out_double = F.mm_dequant(acc, SC, statsA_row, statsB_row)

        # path 2: vectorwise_quant tensors, dequantized by vectorwise_mm_dequant
        A_col32, SA = F.nvidia_transform(A_vq, "col32")
        B_turing, SB = F.nvidia_transform(B_vq, "col_turing")
        acc, SC = F.igemmlt(A_col32, B_turing, SA, SB)
        acc_row, _ = F.nvidia_transform(acc, "row", state=SC)
        out_vector = F.vectorwise_mm_dequant(acc_row.float(), maxA, maxB.t())

        err_double = torch.abs(reference - out_double).mean().item()
        err_vector = torch.abs(reference - out_vector).mean().item()
        assert err_vector <= err_double * 1.025
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Six random (dim1, dim4, inner) triples for test_igemmlt_row_scale (zipped).
n = 6
dim1, dim4, inner = (
    torch.randint(1, 4 * 1024, size=(n,)).tolist() for _ in range(3)
)
values = list(zip(dim1, dim4, inner))
names = [f"dim1_{d1}_dim4_{d4}_inner_{kd}" for d1, d4, kd in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
    """Compare the error of row-scaled int8 igemmlt with other int8 matmul paths.

    Skipped via the decorator above. Nothing is asserted. For each of ``k``
    iterations the mean absolute error against an fp16 matmul is recorded, and
    the averages are printed at the end:

    * err1 -- double_quant(A), double_quant(B), igemmlt, mm_dequant
    * err2 -- igemmlt with int8 output and ``row_scale``, rescaled by hand
    * err3 -- a torch-only emulation using vectorwise_quant
    """
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []  # NOTE(review): never used
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())  # NOTE(review): unused

        # fp16 reference that every error below is measured against
        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        # ``scale`` is carried over from the previous iteration (starts at 1)
        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(
            A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
        )
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        # adapt ``scale`` for the next iteration from the largest int8 output;
        # a value of exactly 127 presumably means the output saturated
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        # presumably undoes the 1/c row scale and both int8 quantization scalings
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        # integer reference product; only used by the commented-out check below
        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        # NOTE: SA / SB are rebound here to quantization scales, not transform states
        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        # print(out1)
        # print(out2)
        # print(out3)
        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

        # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Large (dim1, dim4, inner) cases for test_row_scale_bench; inner = 4 * dim4.
values = [(1024, 12288, 12288 * 4), (2048, 4096, 4096 * 4)]
dim1, dim4, inner = (list(column) for column in zip(*values))
names = [f"dim1_{d1}_dim4_{d4}_inner_{kd}" for d1, d4, kd in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    """Benchmark fp16 matmul against row-scaled and vector-wise int8 igemmlt.

    Prints the wall time of ``k`` iterations for each variant ("16",
    "row-wise", "vector-wise"); nothing is asserted.
    """
    # Bind the weight layout locally. Without this, ``formatB`` resolves to the
    # module-level list used to parametrize test_dequant_mm
    # (["col_turing", "col_ampere"]), which is not a layout string.
    formatB = F.get_special_format_str()
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)

    def timed(label, fn):
        # synchronize on both sides so the asynchronous CUDA work is measured
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(k):
            fn()
        torch.cuda.synchronize()
        print(label, time.time() - t0)

    # warmup
    for _ in range(k):
        torch.matmul(A, B.t())
    timed("16", lambda: torch.matmul(A, B.t()))

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    timed(
        "row-wise",
        lambda: F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale),
    )

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    timed("vector-wise", lambda: F.igemmlt(A2, B2, SA, SB))
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Parametrization for test_transform: two random sizes per axis, every output
# layout, with and without transposition. Only the "row" input order is listed.
n = 2
dim1, dim2 = (torch.randint(2, 1024, size=(n,)).tolist() for _ in range(2))
dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(*vals)
    for vals in values
]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """F.transform must agree with F.nvidia_transform (applied to A.t() when transposing).

    ``orderA`` is part of the parametrization but is not used.
    """
    for _ in range(k):
        if dims == 2:
            shape = (dim1, dim2)
        elif dims == 3:
            shape = (dim1, dim2, dim3)
        A = torch.randint(10, 99, size=shape, device="cuda").to(dtype)
        # plant a single negative value in the last element
        A.view(-1)[-1] = -1

        source = A.t().contiguous() if transpose else A
        expected, state_expected = F.nvidia_transform(source, to_order=orderOut)
        actual, state_actual = F.transform(A, to_order=orderOut, transpose=transpose)

        # state[0] presumably holds the shape; its first two entries must match
        assert state_expected[0][0] == state_actual[0][0]
        assert state_expected[0][1] == state_actual[0][1]
        torch.testing.assert_close(expected, actual)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
# NOTE(review): test_overflow below takes no parameters, so nothing visible here
# consumes these values.
n = 2
dim1, dim2 = [1], [33]
dtype = [torch.int8]
a_order, out_order = ["col_turing"], ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
    f"dim1_{d1}_dim2_{d2}_dtype_{dt}_orderA_{oa}_orderOut_{oo}"
    for d1, d2, dt, oa, oo in values
]
|
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def test_overflow():
    """Run an int8-output igemmlt on operands whose products exceed int8.

    Both operands are the int8 column vector 5..14 (shape (10, 1)). Their
    outer product reaches 14 * 14 = 196, which does not fit in int8. The
    igemmlt result and a float reference are computed twice, but nothing
    is compared.
    NOTE(review): add an assertion once the expected overflow behaviour
    (saturation vs. wrap-around) is pinned down.
    """
    special_order = F.get_special_format_str()
    print(special_order)
    for _ in range(2):
        a, b = (
            torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
            for _rep in range(2)
        )

        a_col32, a_state = F.nvidia_transform(a, "col32")
        b_special, b_state = F.nvidia_transform(b, special_order)

        int8_result = F.igemmlt(a_col32, b_special, a_state, b_state, dtype=torch.int8)
        float_reference = torch.matmul(a.float(), b.float().t())
|
|
|
|
|
|
|
|
|
|
|
|
# Grid for test_coo_double_quant: two random sizes per axis in [1, 4096),
# crossed into four (dim1, dim2) cases. Small fixed sizes such as
# dim1 = [4], dim2 = [5] are handy when debugging.
n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = [(d1, d2) for d1 in dim1 for d2 in dim2]
names = [f"dim1_{d1}_dim2_{d2}" for d1, d2 in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    """Check the outlier split performed by F.double_quant with a threshold.

    For ``k`` random fp16 matrices, entries with |a| >= 3.0 must appear
    exactly in the returned COO tensor. The remaining entries, rebuilt as
    CA * statsA (one value per row, via unsqueeze(1)) / 127, must match A
    within a loose tolerance. Nothing is checked when no COO tensor is
    returned.
    """
    threshold = 3.00
    for _ in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        outlier_mask = torch.abs(A) >= threshold

        # A threshold-free call is made first; its results are discarded.
        _plain_CA, _, _, _, _ = F.double_quant(A)
        CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
        if coo_tensor is None:
            continue

        # Outliers must be stored exactly in COO form.
        rebuilt = torch.zeros_like(A)
        rebuilt[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
        torch.testing.assert_close(A * outlier_mask, rebuilt)

        # Inliers come back from the int8 matrix up to quantization error.
        inliers = A * (outlier_mask == 0)
        dequantized = (CA.float() * statsA.unsqueeze(1) / 127).half()
        torch.testing.assert_close(inliers, dequantized, rtol=0.05, atol=1.5e-2)
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
# Grid for test_spmm_coo: two random sizes per axis in [1, 1024) (fixed sizes
# such as 7 and 11 help when debugging), each with and without transposed B.
n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
transposed_B = [False, True]
values = [(d1, d2, t) for d1 in dim1 for d2 in dim2 for t in transposed_B]
names = [f"dim1_{d1}_dim2_{d2}_transposed_B_{t}" for d1, d2, t in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    """Compare F.spmm_coo on the outliers of A with a dense fp16 matmul.

    Entries of a random (dim1, dim2) fp16 matrix with |a| >= 1.5 form the
    COO operand. The dense operand has a random inner width dim3 in
    [32, 128). With ``transposed_B`` it is created as (dim3, dim2) and
    passed as the ``B.t()`` view. Up to 30 mismatching elements are
    tolerated.
    """
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()  # a fixed 17 helps debugging
    b_shape = (dim3, dim2) if transposed_B else (dim2, dim3)
    for _ in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        B = torch.randn(*b_shape).cuda().half()
        rhs = B.t() if transposed_B else B

        mask = torch.abs(A) >= threshold
        row_ids, col_ids = torch.where(mask)
        cooA = F.COOSparseTensor(
            A.shape[0],
            A.shape[1],
            (mask == 1).sum().item(),
            row_ids.int(),
            col_ids.int(),
            A[mask],
        )

        sparse_out = F.spmm_coo(cooA, rhs)
        dense_out = torch.matmul(A * mask, rhs)
        assert_all_approx_close(dense_out, sparse_out, rtol=0.01, atol=3.0e-2, count=30)
|
|
|
|
|
|
|
|
|
|
|
|
def test_spmm_bench():
    """Time F.spmm_coo on A's outliers against a dense bnb.matmul.

    A is (2 * 1024, 1024) fp16 and B is (1024, 4096) fp16. The entries of A
    with |a| >= 4 form the COO operand. Prints the outlier fraction, the
    sparse and dense times, and their ratio; nothing is asserted.
    """
    def _timed(fn, warmup, reps):
        # Call fn `warmup` times, then return the wall time of `reps` calls
        # measured between device synchronizations.
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(reps):
            fn()
        torch.cuda.synchronize()
        return time.time() - start

    batch, seq, model = 2, 1024, 1024 * 1
    hidden = model * 4
    threshold = 4
    A = torch.randn(batch * seq, model, device="cuda").half()
    B = torch.randn(model, hidden, device="cuda").half()

    t8 = _timed(lambda: bnb.matmul(A, B.t()), 10, k)

    outliers = torch.abs(A) >= threshold
    nnz = (outliers == 1).sum().item()
    print(nnz / outliers.numel())
    row_ids, col_ids = torch.where(outliers)
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, row_ids.int(), col_ids.int(), A[outliers]
    )

    tsp = _timed(lambda: F.spmm_coo(cooA, B), 10, k)
    print(tsp, t8)
    print(tsp / t8)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Grid for test_integrated_sparse_decomp: two random sizes per axis in
# [256, 1024), crossed into four (dim1, dim2) cases.
n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = [(d1, d2) for d1 in dim1 for d2 in dim2]
names = [f"dim1_{d1}_dim2_{d2}" for d1, d2 in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    """Check that the int8 + sparse-outlier decomposition lowers matmul error.

    For ``k`` random fp16 pairs (A, w1), A @ w1.t() is computed three ways:
    in fp16 (the reference); through int8 igemmlt after F.double_quant(A)
    without a threshold argument; and through int8 igemmlt after
    F.double_quant(A, threshold=3.0), plus F.spmm_coo on the returned
    outliers. The decomposed result must have a lower mean absolute error
    than the plain int8 one.
    """
    threshold = 3.0
    formatB = "col_turing"

    def _int8_product(CA, statsA, CTw1, Sw1, statsw1):
        # col32-transform quantized A, multiply with igemmlt, then dequantize.
        C32A, SA = F.transform(CA, "col32")
        acc, acc_state = F.igemmlt(C32A, CTw1, SA, Sw1)
        return F.mm_dequant(acc, acc_state, statsA, statsw1)

    for _ in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
        reference = torch.matmul(A, w1.t())

        Cw1, _, statsw1, _, _ = F.double_quant(w1)
        CTw1, Sw1 = F.transform(Cw1, formatB)

        CA, _, statsA, _, _ = F.double_quant(A)
        int8_only = _int8_product(CA, statsA, CTw1, Sw1, statsw1)

        CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
        int8_part = _int8_product(CA, statsA, CTw1, Sw1, statsw1)

        assert coo_tensor is not None
        decomposed = int8_part + F.spmm_coo(coo_tensor, w1.t())

        err_int8 = torch.abs(reference - int8_only).mean().item()
        err_decomposed = torch.abs(reference - decomposed).mean().item()
        assert err_decomposed < err_int8
|
|
|
|
|
|
|
|
|
|
|
|
def test_matmuls():
    """Compare bnb.matmul and bnb.matmul_cublas with torch.matmul.

    a and b are random (256, 512) fp16 matrices. torch.matmul(a, b.t()) is
    the reference. bnb.matmul receives b as is, and bnb.matmul_cublas
    receives b.t(). Both mean absolute errors must stay below 0.2 and are
    printed.
    """
    a = torch.randn(256, 512).half().cuda()
    b = torch.randn(256, 512).half().cuda()
    reference = torch.matmul(a, b.t())
    candidates = (bnb.matmul(a, b), bnb.matmul_cublas(a, b.t()))

    err1, err2 = (torch.abs(reference - c).mean().item() for c in candidates)
    for err in (err1, err2):
        assert err < 0.2
    print(err1, err2)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Grid for test_spmm_coo_very_sparse: one 2048 x 12288 fp16 case, with the
# output buffer created by torch.zeros and by torch.ones. (Random sizes and
# torch.int8 were tried at some point.)
n = 2
dim1 = [1 * 2048]
dim2 = [12288]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = [
    (d1, d2, dt, fn)
    for d1 in dim1
    for d2 in dim2
    for dt in dtype
    for fn in out_function
]
names = [
    f"dim1_{d1}_dim2_{d2}_dtype_{dt}_out_func_{fn}" for d1, d2, dt, fn in values
]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    """Compare F.spmm_coo_very_sparse with a dense matmul on A's outliers.

    Entries of a random (dim1, dim2) fp16 A with |a| >= 3.3 form the COO
    operand. B is a xavier-initialised fp16 (dim2, 4 * dim2) matrix. When
    ``dtype`` is not torch.float16, B is replaced by the result of
    F.vectorwise_quant(B, quant_type="linear").

    The kernel gets an ``out`` buffer created by ``torch.<out_func>``. The
    reference adds a copy of that buffer's initial contents, i.e. the kernel
    is expected to accumulate into ``out``. Both results are divided by the
    reference's std before the comparison. About 200 mismatches per
    2048 * 12288 * 4 elements are tolerated.
    """
    make_out = getattr(torch, out_func)

    threshold = 3.3  # 2.8 and 0.0 were also tried
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim2 * 4, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    if dtype != torch.float16:
        B, _ = F.vectorwise_quant(B, quant_type="linear")
        # Alternative: B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    mask = torch.abs(A) >= threshold
    row_ids, col_ids = torch.where(mask)
    cooA = F.COOSparseTensor(
        A.shape[0],
        A.shape[1],
        (mask == 1).sum().item(),
        row_ids.int(),
        col_ids.int(),
        A[mask],
    )
    A2 = A * mask

    expected = torch.matmul(A2.half(), B.half())
    out = make_out(expected.shape, dtype=torch.float16, device=expected.device)
    expected += out.clone()
    actual = F.spmm_coo_very_sparse(cooA, B, out=out)

    count = math.ceil(200 / (2048 * 12288 * 4) * expected.numel())
    std = expected.std()
    expected /= std
    actual /= std
    assert_all_approx_close(
        expected, actual.half(), rtol=0.01, atol=3.0e-2, count=count
    )
    # Stricter variants tried: assert_all_approx_close(rtol=0.05, atol=0.01)
    # and torch.testing.assert_close(rtol=0.05, atol=0.001).

    # Random column indices that are drawn but not used.
    torch.randint(0, A2.shape[-1], size=(15,))

    # Timing sketch (disabled):
    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #     # one of: F.spmm_coo(cooA, Bt.t()), F.spmm_coo(cooA, B),
    #     #         F.spmm_coo_very_sparse(cooA, B), torch.matmul(A, Bt.t())
    #     out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print(time.time() - t0)
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
def test_coo2csr():
    """Convert a thresholded 128x128 fp16 COO matrix with F.coo2csr and check it.

    The per-row counts derived from ``rowptr`` must equal the dense per-row
    non-zero counts. The CSR values must equal the dense non-zeros taken in
    row-major order.
    """
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    keep = torch.abs(A) >= threshold
    row_ids, col_ids = torch.where(keep)
    cooA = F.COOSparseTensor(
        A.shape[0],
        A.shape[1],
        (keep == 1).sum().item(),
        row_ids.int(),
        col_ids.int(),
        A[keep],
    )
    sparse_dense = A * keep

    csrA = F.coo2csr(cooA)
    per_row = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert per_row.numel() == A.shape[0]

    nonzero = sparse_dense != 0
    torch.testing.assert_close(per_row.long(), nonzero.sum(1))
    torch.testing.assert_close(sparse_dense[nonzero], csrA.values)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_coo2csc():
    """Convert a thresholded 128x128 fp16 COO matrix with F.coo2csc and check it.

    The per-column counts derived from ``colptr`` must equal the dense
    per-column non-zero counts. The CSC values must equal the dense non-zeros
    taken in column-major order.
    """
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    keep = torch.abs(A) >= threshold
    row_ids, col_ids = torch.where(keep)
    cooA = F.COOSparseTensor(
        A.shape[0],
        A.shape[1],
        (keep == 1).sum().item(),
        row_ids.int(),
        col_ids.int(),
        A[keep],
    )
    sparse_dense = A * keep

    cscA = F.coo2csc(cooA)
    per_col = cscA.colptr[1:] - cscA.colptr[:-1]
    assert per_col.numel() == A.shape[1]

    torch.testing.assert_close(per_col.long(), (sparse_dense != 0).sum(0))
    # Boolean indexing walks row-major, so index the transpose for column order.
    transposed = sparse_dense.t()
    torch.testing.assert_close(transposed[transposed != 0], cscA.values)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Grid for test_spmm_coo_dequant: a single (2048, 2048, torch.int8) case.
# Random sizes, 12288 and tiny 2 x 2 shapes were tried at some point.
n = 2
dim1 = [1 * 2048]
dim2 = [2048]
dtype = [torch.int8]
values = [(d1, d2, dt) for d1 in dim1 for d2 in dim2 for dt in dtype]
names = [f"dim1_{d1}_dim2_{d2}_dtype_{dt}" for d1, d2, dt in values]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    """Check int8 very-sparse COO matmul with fused dequantization, then benchmark.

    A is a random (dim1, dim2) fp16 matrix with 15 random columns forced to
    8.0. The entries with |A| >= 6.0 form the COO operand. B is a
    xavier-initialised fp16 (dim2, 4 * dim2) matrix, quantized with
    F.double_quant.

    Checks:
      * spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) matches the
        fp16-cast int8 product scaled by statsBt / 127;
      * the fused result matches the dense product (A * mask) @ B, allowing
        about 200 mismatches per 2048 * 12288 * 4 elements.

    Afterwards several sparse and dense variants are each timed over 100
    iterations and printed. ``dtype`` is only part of the parametrization.
    """
    threshold = 6.0  # 2.8 and 0.0 were also tried
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    # Turn 15 random columns of A into guaranteed outliers.
    outlier_cols = torch.randint(0, A.shape[-1], size=(15,))
    A[:, outlier_cols] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), A[idx]
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    # Report the median number of non-zeros per occupied row of the COO operand.
    _, counts = torch.unique(cooA.rowidx, return_counts=True)
    max_count, _ = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    count = math.ceil(p * out1.numel())
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    def _bench(label, fn):
        # Time 100 calls of fn between device synchronizations and print it.
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(100):
            fn()
        torch.cuda.synchronize()
        print(label, time.time() - t0)

    def _partial_matmul():
        # Dense int8 matmul, then overwrite its result with the fp16 product
        # restricted to the outlier columns.
        dense = bnb.matmul(A, Bt)
        torch.matmul(A[:, outlier_cols], Bt.t()[outlier_cols], out=dense)

    # (An fp16 F.spmm_coo_very_sparse(cooA, B) timing was also tried.)
    _bench("cusparse fp16", lambda: F.spmm_coo(cooA, B))
    _bench("int8", lambda: F.spmm_coo_very_sparse(cooA, CBt))
    _bench(
        "int8+dequant",
        lambda: F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt),
    )
    _bench("matmul", lambda: torch.matmul(A, B))
    _bench(
        "sparse+ matmul",
        lambda: bnb.matmul(A, Bt)
        + F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt),
    )
    _bench("partial matmul", _partial_matmul)
    # This loop times bnb.matmul alone. It used to reuse the "partial matmul"
    # label, which made the two timings indistinguishable in the output.
    _bench("bnb matmul", lambda: bnb.matmul(A, Bt))
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2023-07-05 02:58:31 +00:00
|
|
|
# Shapes for test_bench_matmul as (batch, seq, model, hidden), hidden = 4 * model.
# Other model widths benchmarked at some point: 768, 1024, 1536, 2048, 2560,
# 4096, 5120, 5140, 6656 and 12288.
batch_size = 1
seqdim = 1
values = [(batch_size, seqdim, width, 4 * width) for width in (8192,)]
names = [f"batch_{b}_seq_{s}_model_{m}_hidden_{h}" for b, s, m, h in values]
|
2022-07-22 21:41:05 +00:00
|
|
|
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    """Print wall-clock timings of matmul variants for one layer shape.

    A is a (batch, seq, model) fp16 activation with 5 random feature columns
    set to 8.0. B is a xavier-initialised (hidden, model) fp16 weight. The
    enabled benchmarks time ``iters`` calls each of torch.matmul(A, B.t())
    and F.cutlass3_gemm on an NF4-quantized B. The other variants (fp4,
    int8 igemmlt pipelines, Linear8bitLt modules) are kept as disabled code
    below. Nothing is asserted.
    """
    iters = 80
    formatB = F.get_special_format_str()  # only used by the disabled int8 runs

    def _run_timed(fn):
        # Synchronize, call fn `iters` times, synchronize again, and return the
        # start timestamp; the caller prints time.time() minus it.
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return start

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
    B_nf4, state_nf4 = F.quantize_nf4(B)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
    linear8bit.eval()

    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
    # linearMixedBit.eval()
    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

    F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)  # one call before any timing

    _run_timed(lambda: torch.matmul(A, B.t()))  # warmup
    print("")

    t0 = _run_timed(lambda: torch.matmul(A, B.t()))
    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    # Disabled: fp4 with plain and with compressed statistics.
    # t0 = _run_timed(lambda: bnb.matmul_4bit(A, B_fp4.t(), quant_state=state))
    # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
    # t0 = _run_timed(lambda: bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c))
    # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    # Alternative nf4 path: bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    t0 = _run_timed(lambda: F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4))
    print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    # Disabled: bnb.matmul converting CB -> CxB on every call, without and with threshold.
    # t0 = _run_timed(lambda: bnb.matmul(A, B))
    # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    # t0 = _run_timed(lambda: bnb.matmul(A, B, threshold=6.0))
    # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # Disabled: igemmlt alone on pre-quantized, pre-transformed operands.
    # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    # C32A, SA = F.transform(CA, "col32")
    # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    # CxB, SB = F.transform(CB, to_order=formatB)
    # t0 = _run_timed(lambda: F.igemmlt(C32A, CxB, SA, SB))
    # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # Disabled: full vector-wise int8 pipeline per iteration.
    # BA, statsB = F.vectorwise_quant(B, dim=1)
    # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    # def _vector_pipeline():
    #     A2 = A.view(-1, A.shape[-1]).contiguous()
    #     CA, statsA = F.vectorwise_quant(A2, dim=1)
    #     C32A, SA = F.nvidia_transform(CA, "col32")
    #     out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #     Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #     F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    # t0 = _run_timed(_vector_pipeline)
    # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # Disabled: full linear int8 pipeline per iteration.
    # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    # CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    # def _linear_pipeline():
    #     A2 = A.view(-1, A.shape[-1]).contiguous()
    #     CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
    #     C32A, SA = F.nvidia_transform(CA, "col32")
    #     out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    #     Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
    #     out = Cout * statsB * statsA * (1.0 / (127 * 127))
    # t0 = _run_timed(_linear_pipeline)
    # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    # Disabled: Linear8bitLt modules, each with one warmup call first.
    # linear8bit(A)
    # t0 = _run_timed(lambda: linear8bit(A))
    # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    # linearMixedBit(A)
    # t0 = _run_timed(lambda: linearMixedBit(A))
    # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    # linear8bit_train(A)
    # t0 = _run_timed(lambda: linear8bit_train(A))
    # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    # linear8bit_train_thresh(A)
    # t0 = _run_timed(lambda: linear8bit_train_thresh(A))
    # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
def test_zeropoint():
    """Exploratory check of zero-point / scale corrections for a matmul.

    A is (1024, 1024) and B is (1024, 4096), both fp16 random values scaled
    by 0.1. Several reconstructions of C1 = A @ B (computed in float32) are
    printed together with their mean absolute errors:

    * C2: zero point subtracted from B, corrected with A's row sums;
    * C3: bnb.matmul on the fp16 inputs;
    * C4: quant_zp-transformed A, corrected with B's column sums;
    * C5: scaled and shifted A, corrected with B's column sums;
    * C6: fixed scales/zero points on both inputs, fully corrected;
    * C7: quant_zp-transformed A and B, fully corrected.

    Nothing is asserted.
    """

    def quant_zp(x):
        # Affine map x -> qx * x + zpx in float32. The result is not rounded.
        # qx = 254 / (max(x) - min(x)), and a zero range is treated as 1.
        # NOTE(review): with this zero point, min(x) lands near
        # 2 * qx * min(x) - 127 rather than -127. C4 and C7 are unaffected,
        # because they subtract the same zero point back out.
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        # Alternatives tried: round(min(x) * qx) and 127 - round(max(x) * qx).
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)  # fp16 product; only referenced by a disabled print

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()

    C1 = torch.matmul(A, B)  # float32 reference
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    # A @ (B - zp) == A @ B - zp * rowsum(A), so the row sums must be ADDED
    # back. The old code subtracted them, which doubled the error instead of
    # cancelling it. Equivalent form: (A - zp) @ B + zp * colsum(B).
    zp = 1
    C2 = torch.matmul(A, B - zp)
    C2 += A.sum(1).view(-1, 1) * zp

    ca, _, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    # (scale * A - zp) @ B == scale * A @ B - zp * colsum(B)
    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    # (qa * A + zpa) @ B == qa * A @ B + zpa * colsum(B)
    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    # (qa*A + zpa) @ (qb*B + zpb) == qa*qb*A@B + qb*zpa*colsum(B)
    #     + qa*zpb*rowsum(A) + zpa*zpb*K, with K = A.shape[1]
    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    # Same identity as C6, with scales/zero points produced by quant_zp.
    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
|
2022-07-26 19:12:38 +00:00
|
|
|
def test_extract_outliers():
    """Check that ``F.extract_outliers`` gathers the requested int8 columns.

    For ``k`` random trials, a random int8 matrix is converted with
    ``F.transform`` to each of the "col_turing" and "col_ampere" formats, and
    the columns listed in ``idx`` are extracted from the transformed tensor.
    The result must have shape ``(rows, len(idx))`` and equal the same columns
    gathered directly from the untransformed matrix.
    """
    shapeA = (4096, 4096 * 4)
    for _ in range(k):
        # Up to 10 distinct column indices (torch.unique may return fewer).
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # randint's upper bound is exclusive: 128 makes the full int8 range
        # [-128, 127] reachable (a bound of 127 never produced 127).
        A = torch.randint(-128, 128, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        for layout in ("col_turing", "col_ampere"):
            CA, SA = F.transform(A, layout)
            outliers2 = F.extract_outliers(CA, SA, idx)

            assert outliers2.shape[0] == shapeA[0]
            assert outliers2.shape[1] == idx.numel()
            torch.testing.assert_close(outliers1, outliers2)
|
2022-09-11 18:55:09 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_blockwise_cpu_large():
    """Round-trip CPU tensors through blockwise quantization.

    For each hidden size and block size, two random normal tensors of shape
    ``(batch, seq, hidden)`` are quantized with ``F.quantize_blockwise`` and
    dequantized again. The mean absolute round-trip error of each must stay
    below 0.011. The elapsed wall time of every round trip is printed.
    """
    diffs = []
    reldiffs = []
    batch = 128
    seq = 128
    # A hidden size of 14336 was also listed originally; it is currently disabled.
    for hidden in [128]:
        for blocksize in [4096, 16384]:
            for _ in range(2):
                A1 = torch.randn(batch, seq, hidden, device='cpu')
                t0 = time.time()
                C, S = F.quantize_blockwise(A1, blocksize=blocksize)
                A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
                print(time.time() - t0)
                diff = torch.abs(A1 - A2)
                # The epsilon goes outside abs() so the denominator is strictly
                # positive; abs(A1 + 1e-8) is 0 when A1 == -1e-8.
                reldiff = diff / (torch.abs(A1) + 1e-8)
                diffs.append(diff.mean().item())
                reldiffs.append(reldiff.mean().item())
                assert diffs[-1] < 0.011
|
2022-11-04 02:49:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fp8_roundtrip_errors(make_input, code=None, iters=100):
    """Round-trip ``iters`` fresh samples through blockwise (de)quantization.

    Args:
        make_input: zero-argument callable returning a new tensor to quantize.
        code: quantization map passed to ``F.quantize_blockwise`` as ``code=``;
            when ``None`` the argument is omitted so the library default is used.
        iters: number of independent samples.

    Returns:
        ``(abserr, relerr)``: per-sample mean absolute and mean relative errors.
    """
    quant_kwargs = {} if code is None else {"code": code}
    abserr = []
    relerr = []
    for _ in range(iters):
        A1 = make_input()
        C, SC = F.quantize_blockwise(A1, **quant_kwargs)
        A2 = F.dequantize_blockwise(C, SC)
        diff = torch.abs(A1 - A2)
        # The epsilon goes outside abs() so the denominator is strictly positive.
        reldiff = diff / (torch.abs(A1) + 1e-8)
        abserr.append(diff.mean().item())
        relerr.append(reldiff.mean().item())
    return abserr, relerr


def test_fp8_quant():
    """Measure blockwise round-trip error for FP8 maps with every bit split.

    For ``e_bits`` in 1..6 and ``p_bits = 7 - e_bits``, normal- and
    uniform-distributed 1024x1024 inputs are round-tripped with the map from
    ``F.create_fp8_map(True, e_bits, p_bits)``. The library's default map is
    measured once on normal inputs as a baseline. No accuracy threshold is
    asserted (the original threshold check was commented out); only that every
    error statistic is finite.
    """
    def normal():
        return torch.randn(1024, 1024, device="cuda")

    def uniform():
        return torch.rand(1024, 1024, device="cuda")

    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()
        for make_input in (normal, uniform):
            abserr, relerr = _fp8_roundtrip_errors(make_input, code=code)
            assert all(math.isfinite(e) for e in abserr + relerr)

    # The default-map baseline does not depend on e_bits, so it is measured
    # once rather than once per FP8 split.
    abserr, relerr = _fp8_roundtrip_errors(normal)
    assert all(math.isfinite(e) for e in abserr + relerr)
|
2022-11-04 02:49:50 +00:00
|
|
|
|
2022-11-06 19:47:54 +00:00
|
|
|
|
|
|
|
def test_few_bit_quant():
    """Compare blockwise quantization with a brute-force nearest-code search.

    For every bit width in 2..8 and every code-construction method, a code is
    built that must have 256 entries and 2**bits or 2**bits - 1 distinct values.
    Random vectors normalised to max |value| == 1 are then quantized both with
    ``F.quantize_blockwise`` and by an argmin search over the code. When every
    index agrees, the two index tensors are compared exactly.
    """
    for bits in range(2, 9):
        for method in ['linear', 'fp8', 'dynamic', 'quantile']:
            abserrs = []
            relerrs = []
            code = None
            if method == 'linear':
                code = F.create_linear_map(True, total_bits=bits).cuda()
            elif method == 'fp8':
                ebits = math.ceil(bits / 2)
                pbits = bits - ebits - 1
                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            elif method == 'dynamic':
                code = F.create_dynamic_map(True, bits, bits).cuda()
            elif method == 'quantile':
                samples = torch.randn(2048, 2048, device='cuda')
                code = F.create_quantile_map(samples, bits).cuda()
            # Depending on the data type the code contains no zero, one zero or
            # two zeros, hence two accepted counts of distinct values.
            assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
            assert code.numel() == 256

            for _ in range(10):
                values = torch.randn(1, 32, device='cuda')
                values /= values.abs().max()

                # Reference: index and value of the nearest code entry.
                q1 = []
                v1 = []
                for v in values[0]:
                    nearest = torch.abs(v - code).argmin()
                    q1.append(nearest.item())
                    v1.append(code[nearest].item())
                q1 = torch.tensor(q1, dtype=torch.int32, device='cuda')
                v1 = torch.tensor(v1, dtype=torch.float32, device='cuda')

                q2, S2 = F.quantize_blockwise(values, code=code)
                v2 = F.dequantize_blockwise(q2, S2)
                # Bring the kernel's indices to q1's dtype and 1-D shape.
                q2_flat = q2.flatten().int()

                err2 = torch.abs(v2 - values)
                abserrs.append(err2.mean().item())
                # The epsilon goes outside abs() so the denominator is strictly positive.
                relerrs.append((err2 / (values.abs() + 1e-10)).mean().item())

                # The previous `if idx.sum():` was inverted: it took the
                # "mismatch" path whenever ANY index matched, so the exact
                # comparison only ran when nothing matched (and then failed on
                # dtype/shape).
                if not torch.equal(q1, q2_flat):
                    # Some inputs land on a different index than the brute-force
                    # search picks. The error-based check is still disabled.
                    err1 = torch.abs(v1 - values).mean()
                    # assert err2.mean() <= err1
                else:
                    torch.testing.assert_close(q1, q2_flat)
|
2022-11-06 21:05:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_kbit_quantile_estimation():
    """Compare ``F.estimate_quantiles`` against exact standard-normal quantiles.

    The reference quantiles depend only on the bit width, so they are computed
    once up front and reused for each of the 100 random samples.
    """
    # Case 1: 2**bits quantiles at probabilities evenly spaced over
    # [1.3e-4, 1 - 1.3e-4], estimated with offset=0.
    edge_refs = {}
    for bits in range(2, 9):
        p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
        edge_refs[bits] = torch.Tensor(norm.ppf(p)).cuda()
    for _ in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits, val1 in edge_refs.items():
            val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.038

    # Case 2: n = 2**bits - 1 quantiles at the bin midpoints (2j + 1) / (2n),
    # j = 0..n-1, estimated with the default offset. (The former
    # linspace/odd-index construction produced these same points and was then
    # overwritten, so it has been removed.)
    mid_refs = {}
    for bits in range(2, 4):
        total_values = 2**bits - 1
        offset = 1 / (2 * total_values)
        p = np.linspace(offset, 1 - offset, total_values)
        mid_refs[bits] = torch.Tensor(norm.ppf(p)).cuda()
    for _ in range(100):
        data = torch.randn(1024, 1024, device='cuda')
        for bits, val1 in mid_refs.items():
            val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
            err = torch.abs(val1 - val2).mean()
            assert err < 0.035
|
2022-11-08 02:06:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_bench_dequantization():
    """Timing loop around ``F.quantize_blockwise`` on a 1024x1024 fp16 tensor.

    One initial quantization uses the code from
    ``F.create_fp8_map(True, 3, 0, 4)``, and its largest index is printed. The
    timed loop then runs 100 calls of ``F.quantize_blockwise`` with the default
    code. Despite the test's name, no dequantization is timed. Nothing is
    asserted.
    """
    src = torch.rand(1024, 1024, device='cuda').half()
    small_code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(src, code=small_code)
    print(qa.max())

    # Roofline-style estimate: 2 bytes per element, in GiB, over 672 — this
    # appears to be microseconds at 672 GiB/s (confirm). Not reported.
    max_theoretical_mu = 1024*1024*2 / 1024**3 / 672 * 1000 * 1000

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        qa, SA = F.quantize_blockwise(src)
    torch.cuda.synchronize()
|
|
|
|
|
2023-02-04 22:52:04 +00:00
|
|
|
|
|
|
|
|
2023-07-05 02:58:31 +00:00
|
|
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_fp4_quant(dtype):
    """Round-trip a random matrix through FP4 quantization with blocksize 64.

    The dequantized tensor must keep the input dtype. The mean absolute error
    must stay below 0.1 and the mean relative error below 0.28.
    """
    A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    # The epsilon keeps the denominator positive. An input element that is
    # exactly zero (possible once randn samples are rounded to a low-precision
    # dtype) would otherwise give NaN/inf and fail the relerr assertion spuriously.
    relerr = (err / (A1.abs().float() + 1e-15)).mean()
    err = err.mean()

    assert A2.dtype == dtype
    assert err.item() < 0.1
    assert relerr.item() < 0.28
|
2023-02-04 22:52:04 +00:00
|
|
|
|
|
|
|
|
2023-04-02 23:10:35 +00:00
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
    """4-bit round trip with plain and with compressed quantization statistics.

    For block sizes 128 and 64, a random fp16 matrix is quantized twice with
    ``F.quantize_4bit`` (``compress_statistics`` off and on) and dequantized.
    For each variant, the mean absolute error must stay below 0.11 and the mean
    relative error below 0.28.
    """
    for blocksize in [128, 64]:
        errs1 = []  # mean abs error, plain statistics
        errs2 = []  # mean abs error, compressed statistics
        for _ in range(10):
            A1 = torch.randn(1024, 1024, device='cuda').half()
            q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
            q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
            A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
            A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)

            for restored, history in ((A2, errs1), (A3, errs2)):
                abs_err = (A1 - restored).abs().float()
                rel_err = (abs_err / (A1.abs().float() + 1e-15)).mean()
                mean_abs = abs_err.mean()
                history.append(mean_abs.item())
                assert mean_abs.item() < 0.11
                assert rel_err.item() < 0.28
|
2023-04-01 23:10:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-04-02 23:10:35 +00:00
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
# The wider ['fp4', 'nf4'] parametrization is currently disabled.
@pytest.mark.parametrize("quant_type", ['nf4'])
def test_bench_4bit_dequant(quant_type):
    """Timing loop for ``F.dequantize_4bit`` on a 49152x12288 fp16 matrix.

    The matrix is quantized once with blocksize 256, then dequantized 100 times
    between CUDA synchronizations. Nothing is asserted or printed.
    """
    blocksize = 256
    weights = torch.rand(1024*12*4, 1024*12, device='cuda').half()
    qa, SA = F.quantize_4bit(weights, blocksize=blocksize, quant_type=quant_type)

    # Bandwidth bound: read numel/2 bytes of packed input, write numel*2 bytes
    # of fp16 output; the divisor 768 is presumably GB/s (confirm).
    packed_bytes = weights.numel() / 2
    unpacked_bytes = weights.numel() * 2
    gigabytes = (packed_bytes + unpacked_bytes) / 1e9
    max_theoretical_s = gigabytes / 768

    # Operand for a matmul baseline timing that is currently disabled.
    b = torch.randn(128, 1024*12, device='cuda').half()

    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
    torch.cuda.synchronize()
    # Per-call time in microseconds: (time.time() - t0) / iters * 1e6
|
2023-04-02 21:42:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_normal_map_tree():
    """Print candidate pivot points over 16 entries of ``F.create_normal_map()``.

    The first 8 and last 8 code entries are concatenated. For 1, 2, 4 and 8
    pivots, the pivot indices are printed, followed by the midpoints between
    each indexed value and its left neighbour. This is exploratory output only;
    nothing is asserted.
    """
    code = F.create_normal_map()
    values = code[:8].tolist() + code[-8:].tolist()
    print(values)
    for num_pivots in (1, 2, 4, 8):
        stride = 16 // num_pivots
        idx = list(range(stride // 2, 16, stride))
        print(idx)
        pivots = [(values[i - 1] + values[i]) / 2 for i in idx]
        print(pivots)
|
|
|
|
|
2023-04-27 00:12:34 +00:00
|
|
|
|
2023-04-30 04:52:47 +00:00
|
|
|
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_cutlass3_gemm(dtype):
    """Compare ``F.cutlass3_gemm`` with ``torch.matmul`` for A (1, dim) times B.t().

    ``debug`` is hard-wired to True, so ``assert_all_approx_close`` is called
    with ``throw=False``. It only counts mismatches, and the test reports the
    mean/max absolute and relative errors via ``print`` instead of failing.
    """
    debug = True
    for dim in [4096]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0
        for _ in range(100):
            A = torch.randn(1, dim, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

            C1 = torch.matmul(A, B.t())
            C2 = F.cutlass3_gemm(A, B.t())

            # Tensor cores are non-deterministic, so errors are analysed around
            # the mean. Statistics are computed in fp32, as test_gemm_4bit does:
            # in fp16 the 1e-8 epsilon rounds to 0, so exactly-zero entries of
            # C1 produced inf/NaN relative errors.
            err = torch.abs(C1 - C2).float()
            mag = torch.abs(C1).float() + 1e-8
            relerr = err/mag
            max_err = max(err.max(), max_err)
            max_relerr = max(relerr.max(), max_relerr)
            err = err.mean().item()
            relerr = relerr.mean().item()

            errs.append(err)
            relerrs.append(relerr)

            # Allowed mismatch count grows with the output size and with dim.
            c = int(C1.numel()*0.0014*(dim/256))+1
            c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
        print('')
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        print(dim, (max_err.item(), max_relerr.item()))
|
2023-04-27 00:12:34 +00:00
|
|
|
|
2023-04-30 04:52:47 +00:00
|
|
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
def test_gemm_4bit(dtype):
    """Compare ``F.cutlass3_gemm`` on NF4-quantized weights with a dense matmul.

    The reference ``C1`` multiplies by the *unquantized* B, so the measured
    error includes quantization error as well as GEMM error. Mismatches are
    counted with ``throw=False``, and the aggregate-error assertions are
    disabled, so this test only fails if a call raises.
    """
    print('')
    for dim in [4096]:
        errs, relerrs = [], []
        max_err = 0
        max_relerr = 0
        for _ in range(100):
            A = torch.randn(1, dim, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

            qB, state = F.quantize_nf4(B)
            # NOTE(review): the dequantized result is discarded; this call only
            # exercises the kernel. Confirm whether it was meant as a reference.
            F.dequantize_nf4(qB, state)

            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
            C1 = torch.matmul(A, B.t())

            # Tensor cores are non-deterministic, so errors are analysed around the mean.
            abs_err = torch.abs(C1-C2).float()
            magnitude = torch.abs(C1).float()+1e-5
            rel_err = abs_err/magnitude
            max_err = max(abs_err.max(), max_err)
            max_relerr = max(rel_err.max(), max_relerr)
            errs.append(abs_err.mean().item())
            relerrs.append(rel_err.mean().item())

            # Allowed mismatch count grows with the output size and with dim.
            budget = int(C1.numel()*0.0014*(dim/256))+1
            assert_all_approx_close(C1, C2, 1e-5, 0.01, count=budget, throw=False)
|
2023-04-27 00:12:34 +00:00
|
|
|
|
2023-05-07 20:34:03 +00:00
|
|
|
# NOTE(review): this skip reason mentions row scale / Ampere, which this paged-
# memory test does not touch; it looks copy-pasted. Confirm the real reason.
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
    """Exercise paged tensors allocated with ``F.get_paged``.

    Checks the paged flags and device id, fills the tensors with ``F.fill``,
    verifies elementwise products, and applies ``F._mul`` three times so that A
    ends up as 17 * 2**3 everywhere. Prefetch-to-CPU checks are not exercised.
    """
    n = 32*10
    total = n*n
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    for tensor in (A, B):
        assert tensor.is_paged
    for tensor in (A, B):
        assert tensor.page_deviceid==0

    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A==17).sum().item() == total
    assert (B==17).sum().item() == total

    C = A*B.float()
    assert (C==289).sum().item() == total

    # Multiply A by B2 (filled with 2) three times.
    for _ in range(3):
        F._mul(A, B2)
    assert (A==17*(2**3)).sum().item() == total
|