From ac3ab281e39cbc514ebef08823482d5b0cba42c1 Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Sat, 25 Feb 2023 06:01:04 +0100
Subject: [PATCH] Handle more cases in test_linear_serialization

---
 tests/test_linear8bitlt.py | 53 ++++++++++++++++++++++++++++----------
 1 file changed, 40 insertions(+), 13 deletions(-)

diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 8edee58..1aafe3d 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -1,4 +1,7 @@
-from copy import deepcopy
+import os
+from contextlib import nullcontext
+from itertools import product
+from tempfile import TemporaryDirectory
 
 import pytest
 import torch
@@ -66,10 +69,11 @@ def test_linear_no_igemmlt():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("has_fp16_weights", [False, True])
-def test_linear_serialization(has_fp16_weights):
-    linear = torch.nn.Linear(16, 32)
-    x = torch.randn(3, 16, dtype=torch.half)
+@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda",
+                         list(product([False, True], [False, True], [False, True])))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda):
+    linear = torch.nn.Linear(32, 96)
+    x = torch.randn(3, 32, dtype=torch.half)
 
     linear_custom = Linear8bitLt(
         linear.in_features,
@@ -78,19 +82,34 @@ def test_linear_serialization(has_fp16_weights):
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
-    linear_custom.state.force_no_igemmlt = True
     linear_custom.weight = bnb.nn.Int8Params(
-        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights
-    ).to(linear.weight.dtype)
+        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+    )
     linear_custom.bias = linear.bias
     linear_custom = linear_custom.cuda()
 
+    if serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
     (fx_first * grad_proj).mean().backward()
 
-    state_dict = deepcopy(linear_custom.state_dict())
+    if not serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
+    with TemporaryDirectory() as tmpdir:
+        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
+        state_path = os.path.join(tmpdir, "state.pth")
+
+        torch.save(linear.state_dict(), state_path)
+        torch.save(state_dict_8bit, state_path_8bit)
+
+        if not has_fp16_weights:
+            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
+
+        new_state_dict = torch.load(state_path_8bit)
 
     new_linear_custom = Linear8bitLt(
         linear.in_features,
@@ -99,13 +118,21 @@ def test_linear_serialization(has_fp16_weights):
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
-    linear_custom.state.force_no_igemmlt = True
+
+    if deserialize_before_cuda:
+        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
+            new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
     new_linear_custom = new_linear_custom.cuda()
-    new_linear_custom.load_state_dict(state_dict, strict=True)
+
+    if not deserialize_before_cuda:
+        new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
     x_second = x.clone().cuda().requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()
 
-    assert torch.allclose(fx_first, fx_second, atol=1e-5)
-    assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
+    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
+    if has_fp16_weights or not deserialize_before_cuda:
+        assert torch.allclose(fx_first, fx_second, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
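
For reference, the serialization round-trip this test now covers, written out as a
minimal usage sketch. This is not part of the patch: it assumes only the
bitsandbytes APIs the test itself exercises (Linear8bitLt, bnb.nn.Int8Params),
and the checkpoint path is illustrative.

    import torch
    import bitsandbytes as bnb
    from bitsandbytes.nn import Linear8bitLt

    # Build an int8 layer from a regular fp32 linear, mirroring the test setup.
    # With has_fp16_weights=False, the state dict holds int8 weights plus their
    # SCB quantization statistics instead of fp16 weights.
    fp32_linear = torch.nn.Linear(32, 96)
    int8_linear = Linear8bitLt(32, 96, has_fp16_weights=False, threshold=6.0)
    int8_linear.weight = bnb.nn.Int8Params(
        fp32_linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    )
    int8_linear.bias = fp32_linear.bias
    int8_linear = int8_linear.cuda()  # quantization happens on the move to GPU

    # The test asserts this file is < 0.5x the size of the fp32 state dict.
    torch.save(int8_linear.state_dict(), "int8_state.pth")

    # Restore into a fresh module, moving it to the GPU *before* loading: as the
    # test asserts, loading 8-bit weights into a CPU module raises RuntimeError.
    restored = Linear8bitLt(32, 96, has_fp16_weights=False, threshold=6.0).cuda()
    restored.load_state_dict(torch.load("int8_state.pth"), strict=True)

That ordering constraint is the point of the deserialize_before_cuda cases: with
has_fp16_weights=False, load_state_dict on a CPU module is expected to fail, and
the final allclose checks are skipped for that combination because the restored
state is incorrect anyway.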