Add force_no_igemmlt to test params
This commit is contained in:
parent
24609b66af
commit
dcecbb26ca
|
@ -69,9 +69,9 @@ def test_linear_no_igemmlt():
|
|||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda",
|
||||
list(product([False, True], [False, True], [False, True])))
|
||||
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda):
|
||||
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
|
||||
list(product([False, True], [False, True], [False, True], [False, True])))
|
||||
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
|
||||
linear = torch.nn.Linear(32, 96)
|
||||
x = torch.randn(3, 32, dtype=torch.half)
|
||||
|
||||
|
@ -82,6 +82,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
|
|||
has_fp16_weights=has_fp16_weights,
|
||||
threshold=6.0,
|
||||
)
|
||||
if force_no_igemmlt:
|
||||
linear_custom.state.force_no_igemmlt = True
|
||||
|
||||
linear_custom.weight = bnb.nn.Int8Params(
|
||||
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
|
||||
)
|
||||
|
@ -118,6 +121,8 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
|
|||
has_fp16_weights=has_fp16_weights,
|
||||
threshold=6.0,
|
||||
)
|
||||
if force_no_igemmlt:
|
||||
new_linear_custom.state.force_no_igemmlt = True
|
||||
|
||||
if deserialize_before_cuda:
|
||||
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
|
||||
|
|
Loading…
Reference in New Issue
Block a user