import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

from itertools import product

def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
    # NOTE: `ldb` is accepted (and swept by the caller below) but is not used
    # by the active path; only the commented-out variants pass an `ld` value.
    k = 25
    all_match = True
    for i in range(k):
        # Random int8 operands; A is 2D or 3D depending on `dims`.
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
        # Float reference for the int8 matmul: C = A @ B^T.
        C1 = torch.matmul(A.float(), B.t().float())

        # Convert the operands into the tiled layouts igemmlt expects.
        A2, SA = F.transform(A, 'col32')
        B2, SB = F.transform(B, 'colx')
        if dims == 2:
            C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
        else:
            C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
        F.igemmlt(A2, B2, C2, SA, SB, SC)
        # Bring the int32 result back to row-major layout for comparison.
        C3, S = F.transform(C2, 'row', state=SC)
        #torch.testing.assert_allclose(C1, C3.float())
        allclose = torch.allclose(C1, C3.float())
        all_match = all_match and allclose
        if allclose:
            print(C1)
            print(C2)
            print(C3)

        ## transposed
        #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
        #if dims == 2:
        #    B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(A.float(), B.float().t())
        #elif dims == 3:
        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(B.float(), A.t().float())
        #    C1 = C1.permute([2, 0, 1])

        #A2, SA = F.transform(A, 'col32')
        #B2, SB = F.transform(B, 'colx')
        #if dims == 2:
        #    C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
        #else:
        #    C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
        #    state = (C2.shape, 'row', A.shape[0])
        #    C2, SC = F.transform(C2, 'col32', state=state)
        #F.igemmlt(A2, B2, C2, SA, SB, SC)
        #C3, S = F.transform(C2, 'row', state=SC, ld=[0])
        #torch.testing.assert_allclose(C1, C3.float())

        ## weight update
        #if dims == 3:
        #    A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())

        #    A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
        #    B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
        #    C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
        #    C2, SC = F.transform(C2, 'col32')
        #    F.igemmlt(B2, A2, C2, SB, SA, SC)
        #    C3, S = F.transform(C2, 'row', state=SC)
        #    torch.testing.assert_allclose(C1, C3.float())

    # Return the aggregate result so the caller's `if val:` check is meaningful
    # (the original returned None, so the sweep below always printed 'nope').
    return all_match

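# A minimal round-trip sketch for F.transform, mirroring the calls inside
# test_igemmlt above: converting a row-major int8 tensor to the tiled 'col32'
# layout and back to 'row' should recover the original values. The helper name
# and the shape defaults are illustrative assumptions, not bitsandbytes API.
def check_transform_roundtrip(rows=8, cols=32):
    X = torch.randint(-128, 127, size=(rows, cols), device='cuda').to(torch.int8)
    X2, SX = F.transform(X, 'col32')          # row-major -> tiled col32
    X3, _ = F.transform(X2, 'row', state=SX)  # tiled col32 -> back to row-major
    return torch.equal(X, X3)
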
dims = (2, 3)
ldb = [0]

n = 2
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))

# Sweep leading-dimension values on a tiny fixed shape to see which ones
# produce results that match the float reference.
for ldb in range(32, 4096, 32):
#for ldb in [None]:
    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
    if val:
        print(val, ldb)
    else:
        print('nope', ldb)

#for val in values:
#    test_igemmlt(*val)
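# A sketch of the exhaustive sweep the commented-out loop above points at,
# assuming test_igemmlt returns True only when all k iterations matched (as
# in the version above). Left commented so the script's behavior is unchanged.
#failures = [val for val in values if not test_igemmlt(*val)]
#print(len(failures), 'failing configs out of', len(values))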