bitsandbytes-rocm/quicktest.py
2022-08-01 09:32:47 -07:00

113 lines
3.9 KiB
Python

from itertools import product
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
    """Compare ``F.igemmlt`` against a float ``torch.matmul`` reference.

    Runs 25 trials. Each trial draws random int8 operands on the CUDA
    device, pushes them through ``F.transform`` / ``F.igemmlt``, converts
    the result back to the "row" layout and compares it with the reference
    product ``A @ B.T`` computed in floating point.

    Args:
        dim1: Size of the first axis of ``A``.
        dim2: Size of the second axis of ``A``; only used when ``dims == 3``.
        dim3: Shared inner dimension (last axis of both ``A`` and ``B``).
        dim4: Number of rows of ``B``, i.e. the last axis of the product.
        dims: Rank of ``A``; must be 2 or 3.
        ldb: Accepted so callers can sweep over it, but not consumed by this
            function body.  NOTE(review): confirm whether it was meant to be
            forwarded to ``F.igemmlt`` / ``F.transform``.

    Returns:
        bool: True when every trial matched the reference, False otherwise.

    Raises:
        ValueError: If ``dims`` is neither 2 nor 3.
    """
    # Fail early with a clear message instead of hitting an unbound ``A``.
    if dims not in (2, 3):
        raise ValueError(f"dims must be 2 or 3, got {dims!r}")

    trials = 25
    all_matched = True
    for _ in range(trials):
        a_shape = (dim1, dim3) if dims == 2 else (dim1, dim2, dim3)
        # torch.randint's upper bound is exclusive, so values lie in [-128, 126].
        A = torch.randint(-128, 127, size=a_shape, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        # Reference result computed in floating point.
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "colx")
        # Output buffer: A's leading axes followed by B's row count.
        C2, SC = F.transform(
            torch.zeros(
                *A.shape[:-1], B.shape[0], dtype=torch.int32, device="cuda"
            ),
            "col32",
        )
        F.igemmlt(A2, B2, C2, SA, SB, SC)
        C3, _ = F.transform(C2, "row", state=SC)
        # torch.testing.assert_allclose(C1, C3.float())

        allclose = torch.allclose(C1, C3.float())
        # Tensors are dumped on a *match*: the driver is searching for a
        # configuration that works and reports "nope" otherwise.
        if allclose:
            print(C1)
            print(C2)
            print(C3)
        all_matched = all_matched and allclose

        # Disabled experiments, kept for reference.
        ## transposed
        # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
        # if dims == 2:
        #     B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
        #     C1 = torch.matmul(A.float(), B.float().t())
        # elif dims == 3:
        #     B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #     C1 = torch.matmul(B.float(), A.t().float())
        #     C1 = C1.permute([2, 0, 1])
        # A2, SA = F.transform(A, 'col32')
        # B2, SB = F.transform(B, 'colx')
        # if dims == 2:
        #     C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
        # else:
        #     C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
        #     state = (C2.shape, 'row', A.shape[0])
        #     C2, SC = F.transform(C2, 'col32', state=state)
        # F.igemmlt(A2, B2, C2, SA, SB, SC)
        # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
        # torch.testing.assert_allclose(C1, C3.float())
        ## weight update
        # if dims == 3:
        #     A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #     B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
        #     C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
        #     A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
        #     B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
        #     C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
        #     C2, SC = F.transform(C2, 'col32')
        #     F.igemmlt(B2, A2, C2, SB, SA, SC)
        #     C3, S = F.transform(C2, 'row', state=SC)
        #     torch.testing.assert_allclose(C1, C3.float())

    # The caller branches on this value; report success only if all trials agreed.
    return all_matched
# Parameter grid for a randomized sweep.  ``values`` is only consumed by the
# disabled loop at the bottom of the script.
dims = (2, 3)
ldb = [0]
n = 2
# Draw ``n`` random sizes per dimension; the (low, high) pairs are evaluated
# in order, so the random draws happen in the sequence dim1..dim4.
dim1, dim2, dim3, dim4 = (
    torch.randint(low, high, size=(n,)).tolist()
    for low, high in ((1, 256), (32, 512), (32, 1024), (32, 1024))
)
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))

# Sweep ``ldb`` over multiples of 32 on a fixed 2x2 two-dimensional problem
# and report, per step, whether test_igemmlt returned a truthy value.
# Note that this rebinds the module-level ``ldb`` defined above.
for ldb in range(32, 4096, 32):
    # for ldb in [None]:
    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
    if not val:
        print("nope", ldb)
        continue
    print(val, ldb)

# for val in values:
#     test_igemmlt(*val)