Handle more cases in test_linear_serialization
This commit is contained in:
parent 58b09ee1b1
commit ac3ab281e3
@@ -1,4 +1,7 @@
 from copy import deepcopy
+import os
+from contextlib import nullcontext
+from itertools import product
 from tempfile import TemporaryDirectory
 
 import pytest
 import torch
@@ -66,10 +69,11 @@ def test_linear_no_igemmlt():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("has_fp16_weights", [False, True])
-def test_linear_serialization(has_fp16_weights):
-    linear = torch.nn.Linear(16, 32)
-    x = torch.randn(3, 16, dtype=torch.half)
+@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda",
+                         list(product([False, True], [False, True], [False, True])))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda):
+    linear = torch.nn.Linear(32, 96)
+    x = torch.randn(3, 32, dtype=torch.half)
 
     linear_custom = Linear8bitLt(
         linear.in_features,
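Note: the three-argument product(...) call above is equivalent to product([False, True], repeat=3) and expands the test into all 8 parameter combinations. A quick standalone sketch (plain itertools, not part of the test file):

from itertools import product

# Enumerate the 8 (has_fp16_weights, serialize_before_forward, deserialize_before_cuda)
# combinations that the new parametrize decorator generates.
for has_fp16_weights, serialize_before_forward, deserialize_before_cuda in product(
    [False, True], repeat=3
):
    print(has_fp16_weights, serialize_before_forward, deserialize_before_cuda)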
@@ -78,19 +82,34 @@ def test_linear_serialization(has_fp16_weights):
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
     linear_custom.state.force_no_igemmlt = True
     linear_custom.weight = bnb.nn.Int8Params(
-        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights
-    ).to(linear.weight.dtype)
+        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
+    )
     linear_custom.bias = linear.bias
     linear_custom = linear_custom.cuda()
 
+    if serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
     (fx_first * grad_proj).mean().backward()
 
-    state_dict = deepcopy(linear_custom.state_dict())
+    if not serialize_before_forward:
+        state_dict_8bit = linear_custom.state_dict()
 
+    with TemporaryDirectory() as tmpdir:
+        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
+        state_path = os.path.join(tmpdir, "state.pth")
+
+        torch.save(linear.state_dict(), state_path)
+        torch.save(state_dict_8bit, state_path_8bit)
+
+        if not has_fp16_weights:
+            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
+
+        new_state_dict = torch.load(state_path_8bit)
+
     new_linear_custom = Linear8bitLt(
         linear.in_features,
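The 0.5 factor in the size assertion follows from int8 storage using one byte per weight versus two bytes for fp16, so the 8-bit checkpoint should land at roughly half the size of the fp16 one. A standalone back-of-the-envelope check (plain PyTorch, not part of the commit):

import torch

# Same shape as the test layer (out_features=96, in_features=32).
fp16_weight = torch.zeros(96, 32, dtype=torch.half)
int8_weight = torch.zeros(96, 32, dtype=torch.int8)

# One byte per int8 element vs. two per fp16 element is what makes the
# serialized 8-bit state dict come in under half the fp16 file size.
assert fp16_weight.element_size() == 2
assert int8_weight.element_size() == 1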
@@ -99,13 +118,21 @@ def test_linear_serialization(has_fp16_weights):
         has_fp16_weights=has_fp16_weights,
         threshold=6.0,
     )
     linear_custom.state.force_no_igemmlt = True
 
+    if deserialize_before_cuda:
+        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
+            new_linear_custom.load_state_dict(new_state_dict, strict=True)
+
     new_linear_custom = new_linear_custom.cuda()
-    new_linear_custom.load_state_dict(state_dict, strict=True)
+
+    if not deserialize_before_cuda:
+        new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
     x_second = x.clone().cuda().requires_grad_(True)
     fx_second = new_linear_custom(x_second).float()
     (fx_second * grad_proj).mean().backward()
 
-    assert torch.allclose(fx_first, fx_second, atol=1e-5)
-    assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
+    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
+    if has_fp16_weights or not deserialize_before_cuda:
+        assert torch.allclose(fx_first, fx_second, atol=1e-5)
+        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
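The with-statement in the deserialize_before_cuda branch picks its context manager at runtime: nullcontext() when the load is expected to succeed (fp16 weights), pytest.raises(RuntimeError) when loading 8-bit state into a module still on CPU is expected to fail, per the comment in the diff. A minimal standalone sketch of that pattern (hypothetical should_fail flag and load helper, not from this commit):

from contextlib import nullcontext

import pytest

def load(should_fail: bool):
    # Stand-in for load_state_dict: raises only in the unsupported configuration.
    if should_fail:
        raise RuntimeError("cannot load 8-bit state into a module still on CPU")

@pytest.mark.parametrize("should_fail", [False, True])
def test_conditional_expectation(should_fail):
    # nullcontext() is a no-op wrapper; pytest.raises asserts the error occurs.
    with nullcontext() if not should_fail else pytest.raises(RuntimeError):
        load(should_fail)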