Handle more cases in test_linear_serialization

commit ac3ab281e3
parent 58b09ee1b1
Author: Max Ryabinin
Date:   2023-02-25 06:01:04 +01:00


@@ -1,4 +1,7 @@
 from copy import deepcopy
+import os
+from contextlib import nullcontext
+from itertools import product
+from tempfile import TemporaryDirectory
 
 import pytest
 import torch
@@ -66,10 +69,11 @@ def test_linear_no_igemmlt():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("has_fp16_weights", [False, True])
-def test_linear_serialization(has_fp16_weights):
-    linear = torch.nn.Linear(16, 32)
-    x = torch.randn(3, 16, dtype=torch.half)
+@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda",
+                         list(product([False, True], [False, True], [False, True])))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda):
+    linear = torch.nn.Linear(32, 96)
+    x = torch.randn(3, 32, dtype=torch.half)
 
     linear_custom = Linear8bitLt(
         linear.in_features,
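
Note on the new parametrization: it expands to the full Cartesian product of the three boolean flags, so the test now runs 2**3 = 8 times (`product([False, True], repeat=3)` would be an equivalent spelling). A minimal sketch of what the expansion yields:

    from itertools import product

    combos = list(product([False, True], [False, True], [False, True]))
    assert len(combos) == 8                    # 2**3 parameter combinations
    assert combos[0] == (False, False, False)  # serialize after forward, load after .cuda()
    assert combos[-1] == (True, True, True)    # serialize before forward, load before .cuda()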
@@ -78,19 +82,34 @@ def test_linear_serialization(has_fp16_weights):
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
+    linear_custom.state.force_no_igemmlt = True
+
     linear_custom.weight = bnb.nn.Int8Params(
-        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights
-    ).to(linear.weight.dtype)
+        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+    )
     linear_custom.bias = linear.bias
     linear_custom = linear_custom.cuda()
 
+    if serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
     (fx_first * grad_proj).mean().backward()
 
-    state_dict = deepcopy(linear_custom.state_dict())
+    if not serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
+    with TemporaryDirectory() as tmpdir:
+        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
+        state_path = os.path.join(tmpdir, "state.pth")
+
+        torch.save(linear.state_dict(), state_path)
+        torch.save(state_dict_8bit, state_path_8bit)
+
+        if not has_fp16_weights:
+            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
+
+        new_state_dict = torch.load(state_path_8bit)
 
     new_linear_custom = Linear8bitLt(
         linear.in_features,
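
Two things are worth noting about the new save/load block. Serializing before versus after the first forward pass exercises different weight states: with has_fp16_weights=False, the first forward quantizes the weight into the GPU layout, so state_dict() has to work in both cases. And the file-size assertion holds because the reference nn.Linear is fp32 while the 8-bit checkpoint stores one byte per weight element plus per-row scale constants. A back-of-envelope sketch (assuming the scale constants, SCB, and the bias are stored as fp32, and ignoring torch.save container overhead):

    fp32_bytes = 96 * 32 * 4 + 96 * 4            # fp32 weight + fp32 bias
    int8_bytes = 96 * 32 * 1 + 96 * 4 + 96 * 4   # int8 weight + fp32 SCB + fp32 bias
    assert int8_bytes < 0.5 * fp32_bytes         # 3840 < 6336, matching the getsize check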
@@ -99,13 +118,21 @@ def test_linear_serialization(has_fp16_weights):
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
+    linear_custom.state.force_no_igemmlt = True
+
+    if deserialize_before_cuda:
+        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
+            new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
     new_linear_custom = new_linear_custom.cuda()
-    new_linear_custom.load_state_dict(state_dict, strict=True)
+
+    if not deserialize_before_cuda:
+        new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
     x_second = x.clone().cuda().requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()
 
-    assert torch.allclose(fx_first, fx_second, atol=1e-5)
-    assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
+    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
+    if has_fp16_weights or not deserialize_before_cuda:
+        assert torch.allclose(fx_first, fx_second, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
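
The deserialize_before_cuda branch leans on a conditional context manager: nullcontext() expects the load to succeed, while pytest.raises(RuntimeError) expects it to fail (loading int8 weights into a module still on CPU). A self-contained sketch of that idiom, with an illustrative error standing in for the real load:

    from contextlib import nullcontext
    import pytest

    @pytest.mark.parametrize("should_fail", [False, True])
    def test_conditional_expectation(should_fail):
        # Pick the expectation up front instead of duplicating the test body:
        # a no-op context when the call should succeed, pytest.raises otherwise.
        expectation = pytest.raises(RuntimeError) if should_fail else nullcontext()
        with expectation:
            if should_fail:
                raise RuntimeError("illustrative failure")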