Add force_no_igemmlt to test params

Max Ryabinin 2023-03-22 00:28:49 +01:00
parent 24609b66af
commit dcecbb26ca


@@ -69,9 +69,9 @@ def test_linear_no_igemmlt():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda",
-                         list(product([False, True], [False, True], [False, True])))
-def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda):
+@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
+                         list(product([False, True], [False, True], [False, True], [False, True])))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
     linear = torch.nn.Linear(32, 96)
     x = torch.randn(3, 32, dtype=torch.half)
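
Note: widening the grid from three boolean parameters to four doubles the number of generated test cases from 8 to 16. A quick standalone sketch of what the `itertools.product` call expands to:

    from itertools import product

    # Each boolean flag doubles the grid: 2**4 = 16 parameter combinations,
    # up from the 2**3 = 8 produced by the previous three-flag version.
    combos = list(product([False, True], [False, True], [False, True], [False, True]))
    assert len(combos) == 16
    assert combos[0] == (False, False, False, False)
    assert combos[-1] == (True, True, True, True)
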
@@ -82,6 +82,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
+    if force_no_igemmlt:
+        linear_custom.state.force_no_igemmlt = True
+
     linear_custom.weight = bnb.nn.Int8Params(
         linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
     )
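
For context, `force_no_igemmlt` is a field on the layer's `MatmulLtState`; setting it routes the int8 matmul through the fallback path instead of the igemmlt kernel. A minimal sketch of the same setup outside the test harness (assumes a CUDA device; the constructor arguments mirror the hunk above):

    import torch
    import bitsandbytes as bnb
    from bitsandbytes.nn import Linear8bitLt

    # Sketch only: build an 8-bit linear layer that skips the igemmlt kernel.
    ref = torch.nn.Linear(32, 96)
    layer = Linear8bitLt(32, 96, ref.bias is not None,
                         has_fp16_weights=False, threshold=6.0)
    layer.state.force_no_igemmlt = True  # same toggle as in the diff above
    layer.weight = bnb.nn.Int8Params(ref.weight.data.clone(),
                                     requires_grad=False, has_fp16_weights=False)
    layer.bias = ref.bias
    layer = layer.cuda()  # weights are quantized on the move to the GPU
    out = layer(torch.randn(3, 32, dtype=torch.half, device="cuda"))
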
@@ -118,6 +121,8 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
+    if force_no_igemmlt:
+        new_linear_custom.state.force_no_igemmlt = True
     if deserialize_before_cuda:
         with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
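
The `nullcontext()`/`pytest.raises` line in the trailing context is the usual conditional-expectation pattern: the same block must raise in one configuration and succeed in the other. A self-contained toy illustrating the pattern (not the commit's code):

    from contextlib import nullcontext

    import pytest

    @pytest.mark.parametrize("should_fail", [False, True])
    def test_conditional_raise(should_fail):
        # Expect a RuntimeError only in the failing configuration;
        # otherwise run the same block under a no-op context manager.
        with pytest.raises(RuntimeError) if should_fail else nullcontext():
            if should_fail:
                raise RuntimeError("simulated failure")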