# bitsandbytes-rocm/tests/test_autograd.py
#
# 432 lines, 15 KiB, Python
# (exported from the repository web view: Raw | Normal View | History)

2022-08-16 19:00:54 +00:00
from itertools import product, permutations
2022-07-22 21:41:05 +00:00
import pytest
2022-07-22 21:41:05 +00:00
import torch
import bitsandbytes as bnb
2022-07-22 21:41:05 +00:00
# ---------------------------------------------------------------------------
# Parametrization for test_matmul (bnb cuBLAS bmm/matmul vs. torch).
#
# NOTE: `n`, `k` and most of the names below are rebound further down this
# module for test_matmullt; any function that reads them at call time (rather
# than through the decorator below) sees the later values.
# ---------------------------------------------------------------------------
n = 1  # number of random sizes drawn per dimension
k = 25  # iterations per parametrized case intended for test_matmul
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

# (reference op, bitsandbytes op) pairs and the labels used in test ids.
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]

# requires_grad flags for (A, B), labelled with one T/F letter per flag.
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]

# Whether (A, B) are transposed before the 2-D multiply.
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]

dtype = [torch.float32, torch.float16]

values = list(
    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
names = [
    f"dim1_{d1}_dim2_{d2}_dim3_{d3}_dim4_{d4}_func_{fn}_dtype_{dt}"
    f"_requires_grad_{rg}_transpose_{tr}"
    for d1, d2, d3, d4, fn, dt, rg, tr in str_values
]
# Iteration count for test_matmul, captured once at import time. The
# module-level `k` is rebound to 3 further down this file (for
# test_matmullt), so reading `k` inside the test body at call time would run
# only 3 iterations instead of the value configured for this test.
_MATMUL_ITERATIONS = k


def _maybe_t(x, flag):
    """Return ``x.t()`` when ``flag`` is true, otherwise ``x`` unchanged."""
    return x.t() if flag else x


def _num_mismatches(actual, expected, atol, rtol):
    """Count the elements where ``torch.isclose(actual, expected)`` is False."""
    idx = torch.isclose(actual, expected, atol=atol, rtol=rtol)
    return (idx == 0).sum().item()


def _assert_backward_matches(
    out_bnb, out_torch, target, A, B, req_grad, check_gradB_allclose=True
):
    """Compare gradients of an MSE loss taken through the bnb and torch ops.

    ``out_bnb``'s data is overwritten with ``out_torch`` first, so both losses
    and their incoming gradients are identical and the comparison isolates
    the two backward passes. ``A.grad`` / ``B.grad`` are reset to ``None``
    after each pass so each pass's gradients are read on their own.

    Args:
        out_bnb, out_torch: outputs of the bnb op and the reference op.
        target: regression target for the MSE loss.
        A, B: the operands whose gradients are compared.
        req_grad: requires_grad flags for (A, B); nothing is checked for an
            operand whose flag is False.
        check_gradB_allclose: also require every element of B's gradient to
            be within atol=0.18, rtol=0.3.
    """
    if not any(req_grad):
        return

    out_bnb.data.copy_(out_torch)
    torch.cuda.synchronize()
    loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
    loss_bnb.backward()
    gradA1 = A.grad
    gradB1 = B.grad
    A.grad = None
    B.grad = None

    loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
    loss_torch.backward()
    gradA2 = A.grad
    gradB2 = B.grad
    A.grad = None
    B.grad = None

    if req_grad[0]:
        torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
    if req_grad[1]:
        n = gradB1.numel()
        assert _num_mismatches(gradB1, gradB2, atol=0.06, rtol=0.3) < n * 0.1
        assert _num_mismatches(gradB1, gradB2, atol=0.10, rtol=0.3) < n * 0.02
        if check_gradB_allclose:
            torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    """Compare bnb's cuBLAS-backed bmm/matmul with torch, forward and backward.

    Depending on ``funcs[0]``, each iteration runs up to three scenarios:
    a 2-D multiply with optional transposes (mm/matmul), a batched 3-D x 3-D
    multiply (bmm/matmul), and a 3-D x 2-D multiply where B is broadcast over
    the batch (matmul only).

    NOTE(review): ``dtype`` is parametrized but never applied. Tensors are
    created with torch's default dtype, so the float16 cases currently repeat
    the float32 ones. Confirm the intent before wiring it in, because the
    tolerances below were tuned without it.
    """
    if not torch.cuda.is_available():
        pytest.skip('No GPU found.')

    # Round sizes down to multiples of 16 (dim2 only when positive).
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    # Rounded batch size for the broadcast case. It is kept apart from `dim1`
    # so the batched case uses the same batch size on every iteration.
    dim1_16 = dim1 - (dim1 % 16)

    for _ in range(_MATMUL_ITERATIONS):
        # 2-D multiply, optionally with transposed operands.
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim3, dim2) if transpose[0] else (dim2, dim3)
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            # Each op gets its own transposed views, so the two autograd
            # graphs share only the leaf tensors.
            out_torch = funcs[0](
                _maybe_t(A, transpose[0]), _maybe_t(B, transpose[1])
            )
            out_bnb = funcs[1](
                _maybe_t(A, transpose[0]), _maybe_t(B, transpose[1])
            )

            n = out_bnb.numel()
            assert _num_mismatches(out_bnb, out_torch, 0.01, 0.1) < n * 0.0175
            assert _num_mismatches(out_bnb, out_torch, 0.035, 0.2) < n * 0.001

            _assert_backward_matches(
                out_bnb, out_torch, target, A, B, req_grad,
                check_gradB_allclose=True,
            )

        # Batched 3-D x 3-D multiply (transpose flags are not used here).
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(
                size=(dim1, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            B = torch.randn(
                size=(dim1, dim3, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            target = torch.randn(
                size=(dim1, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            assert _num_mismatches(out_bnb, out_torch, 0.01, 0.1) < n * 0.01
            torch.testing.assert_allclose(
                out_bnb, out_torch, atol=0.027, rtol=0.2
            )

            _assert_backward_matches(
                out_bnb, out_torch, target, A, B, req_grad,
                check_gradB_allclose=False,
            )

        # 3-D x 2-D multiply (B broadcast over the batch), B optionally
        # transposed.
        if funcs[0] in [torch.matmul]:
            A = torch.randn(
                size=(dim1_16, dim2, dim3),
                device="cuda",
                requires_grad=req_grad[0],
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1_16, dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, _maybe_t(B, transpose[1]))
            out_bnb = funcs[1](A, _maybe_t(B, transpose[1]))

            n = out_bnb.numel()
            assert _num_mismatches(out_bnb, out_torch, 0.01, 0.1) < n * 0.0175
            assert _num_mismatches(out_bnb, out_torch, 0.035, 0.2) < n * 0.001

            _assert_backward_matches(
                out_bnb, out_torch, target, A, B, req_grad,
                check_gradB_allclose=False,
            )
2022-07-22 21:41:05 +00:00
# ---------------------------------------------------------------------------
# Parametrization for test_matmullt (bnb.matmul with MatmulLtState vs. torch).
#
# NOTE: this section rebinds module-level names already defined earlier in
# this file (`n`, `k`, `dim1`..`dim4`, `funcs`, `req_grad`, `transpose`,
# `dtype`, `values`, `names`, ...). Decorators above have already consumed
# the earlier values, but any function that reads these globals at call time
# sees the values assigned here.
# ---------------------------------------------------------------------------
n = 1  # number of random sizes drawn per dimension
k = 3  # iterations per parametrized case
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

# Also cover an A with zero rows.
dim2.append(0)

# Values assigned to MatmulLtState.threshold. With 6.0, the test also writes
# 6.0 into a random subset of A's columns.
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]

# requires_grad flags for (A, B, bias): all 8 combinations, labelled e.g. "TFT".
req_grad = list(product([True, False], repeat=3))
req_grad_str = [
    "".join("T" if flag else "F" for flag in flags) for flags in req_grad
]

# (transpose A, transpose B): "NT" computes A @ B.t(), "NN" computes A @ B.
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]

dtype = [torch.float16, torch.bfloat16, torch.float32]

has_fp16_weights = [True, False]
has_bias = [True, False]

values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
        has_bias,
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
        has_bias,
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}"
    "_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals)
    for vals in str_values
]
@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
    values,
    ids=names,
)
def test_matmullt(
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    decomp,
    has_fp16_weights,
    has_bias
):
    """Compare ``bnb.matmul`` driven by a ``bnb.MatmulLtState`` with torch.

    For each of ``k`` iterations, the test:
    - builds A, B and an optional bias. When ``decomp == 6.0``, a random
      subset of A's columns is set to 6.0.
    - runs the reference ``funcs[0]`` and ``bnb.matmul``.
    - checks that the bnb output keeps A's dtype and that only a small
      fraction of elements disagree.

    When ``has_fp16_weights`` is set, it also compares the gradients of an MSE
    loss with respect to A, B and the bias.

    ``req_grad`` holds the requires_grad flags for (A, B, bias). Only
    ``transpose`` values with ``transpose[0] == False`` are supported.
    """
    if not torch.cuda.is_available():
        pytest.skip('No GPU found.')

    dimA = (dim3, dim2) if transpose[0] else (dim2, dim3)
    dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
    # Random (possibly repeated) column indices of A, filled with 6.0 when
    # decomp == 6.0.
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
    if not has_bias:
        # Without a bias there is no bias gradient to request or compare.
        req_grad = list(req_grad)
        req_grad[2] = False

    for _ in range(k):
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(
                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
            )
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(
                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            target = torch.randn(
                size=(dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
                dtype=dtype,
            )
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(
                    dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]
                )
                # bnb receives a clone and the reference adds `bias` itself.
                # Both paths therefore accumulate into `bias.grad`.
                bias2 = bias.clone()

            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()
            state = bnb.MatmulLtState()
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
                # Pre-quantize B (cast to float16) with double_quant and store
                # CB/SCB on the state. The int8 CB then stands in for B.
                (
                    state.CB,
                    _,
                    state.SCB,
                    _,
                    _,
                ) = bnb.functional.double_quant(B2.to(torch.float16))
                B2 = state.CB

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2, state=state, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)
            else:
                raise NotImplementedError(
                    f"test_matmullt does not support transpose={transpose}"
                )
            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

            # `<=` (not `<`) so that empty outputs (dim2 == 0, n == 0) pass.
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() <= n * 0.001

            if has_fp16_weights and any(req_grad):
                # Give both paths identical output values, so the loss and its
                # incoming gradient match and only the backward passes differ.
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(
                    out_bnb, target
                ).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias1 = bias.grad
                    bias.grad = None

                loss_torch = torch.nn.functional.mse_loss(
                    out_torch, target
                ).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias2 = bias.grad
                    bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(
                        gradA1, gradA2, atol=0.015, rtol=0.1
                    )
                if req_grad[1]:
                    n = gradB1.numel()
                    if dim2 > 0:
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                    else:
                        # With zero rows in A, B's gradient must be all zero.
                        assert torch.abs(gradB1).sum() == 0.0
                        assert torch.abs(gradB2).sum() == 0.0
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() <= n * 0.02
                    torch.testing.assert_allclose(
                        gradB1, gradB2, atol=0.18, rtol=0.3
                    )
                if req_grad[2]:
                    torch.testing.assert_allclose(gradBias1, gradBias2)