2023-02-25 05:01:04 +00:00
|
|
|
import os
|
|
|
|
from contextlib import nullcontext
|
|
|
|
from itertools import product
|
|
|
|
from tempfile import TemporaryDirectory
|
2023-02-21 11:04:47 +00:00
|
|
|
|
2023-02-02 04:09:31 +00:00
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
2023-02-21 11:04:47 +00:00
|
|
|
import bitsandbytes as bnb
|
|
|
|
from bitsandbytes import functional as F
|
2023-02-02 04:09:31 +00:00
|
|
|
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
|
|
|
|
from bitsandbytes.nn.modules import Linear8bitLt
|
|
|
|
|
2023-02-21 11:04:47 +00:00
|
|
|
|
2023-02-02 04:09:31 +00:00
|
|
|
# contributed by Alex Borzunov, see:
|
|
|
|
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
|
|
|
|
|
|
|
|
@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
    """Check that ``undo_layout`` exactly inverts ``F.transform`` for the tiled layouts.

    For each ``(tile_size, order)`` pair, a row-major int8 matrix is converted with
    ``F.transform(..., from_order="row", to_order=order)``. The tile permutation is
    recovered from that same transform via ``get_inverse_transform_indices``, and
    ``undo_layout`` must then return a contiguous tensor equal to the input.
    """
    # Scale in place: ``randn(...) * 10`` would allocate a second host buffer of the
    # same size (~2.5 GB with the default float32 dtype) before the int8 cast.
    # The resulting values are identical.
    x = torch.randn(14336 * 3, 14336).mul_(10).to(torch.int8).cuda()

    for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):

        def transform(tensor):
            # Apply the layout conversion on the GPU and hand the result back on the
            # input's device; only element 0 of F.transform's result is used.
            # Closes over this iteration's ``order`` and is only called within it.
            converted = F.transform(tensor.cuda(), from_order="row", to_order=order)[0]
            return converted.to(tensor.device)

        tile_indices = get_inverse_transform_indices(transform, tile_size)
        cxb = transform(x)

        # NOTE(review): presumably synchronizes so that asynchronous CUDA failures
        # surface at the step that caused them.
        torch.cuda.synchronize()
        restored_x = undo_layout(cxb, tile_indices)
        torch.cuda.synchronize()

        assert restored_x.is_contiguous()
        assert torch.all(torch.eq(restored_x, x))
|
|
|
|
|
2023-02-21 11:04:47 +00:00
|
|
|
|
2023-02-02 04:09:31 +00:00
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt():
    """Linear8bitLt with ``force_no_igemmlt`` must track a half-precision nn.Linear.

    Outputs must agree within ``atol=0.02`` and input gradients within ``atol=0.01``.
    Afterwards the module must still report int8 (non-fp16) weights, have
    ``state.CB`` populated and leave ``state.CxB`` unset.
    """
    fp_layer = torch.nn.Linear(1024, 3072)
    inputs = torch.randn(3, 1024, dtype=torch.half)

    int8_layer = Linear8bitLt(
        fp_layer.in_features,
        fp_layer.out_features,
        fp_layer.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    int8_layer.state.force_no_igemmlt = True

    source_weight = fp_layer.weight.data.clone()
    int8_layer.weight = bnb.nn.Int8Params(
        source_weight, requires_grad=False, has_fp16_weights=False
    ).to(fp_layer.weight.dtype)
    # Assigns the very same Parameter object, so both layers share one bias.
    int8_layer.bias = fp_layer.bias

    int8_layer = int8_layer.cuda()
    fp_layer = fp_layer.half().cuda()

    ref_in = inputs.clone().cuda().requires_grad_(True)
    our_in = inputs.clone().cuda().requires_grad_(True)

    # Reference pass; the random projection turns the output into a scalar loss.
    ref_out = fp_layer(ref_in).float()
    projection = torch.randn_like(ref_out)
    ref_loss = (ref_out * projection).mean()
    ref_loss.backward()

    # 8-bit pass, reusing the same projection so gradients are comparable.
    our_out = int8_layer(our_in).float()
    our_loss = (our_out * projection).mean()
    our_loss.backward()

    assert torch.allclose(ref_out, our_out, atol=0.02)
    assert torch.allclose(ref_in.grad, our_in.grad, atol=0.01)
    assert not int8_layer.state.has_fp16_weights
    assert int8_layer.state.CB is not None
    assert int8_layer.state.CxB is None
|
2023-02-21 11:04:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _build_linear8bitlt(reference, has_fp16_weights, force_no_igemmlt):
    """Create a Linear8bitLt matching ``reference``'s in/out features and bias presence.

    Used for both the original and the reloaded module in the serialization test so
    the two are constructed identically. Not collected by pytest (no ``test_`` prefix).
    """
    module = Linear8bitLt(
        reference.in_features,
        reference.out_features,
        reference.bias is not None,
        has_fp16_weights=has_fp16_weights,
        threshold=6.0,
    )
    if force_no_igemmlt:
        module.state.force_no_igemmlt = True
    return module


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize(
    "has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
    list(product([False, True], [False, True], [False, True], [False, True])),
)
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
    """Round-trip a Linear8bitLt through ``state_dict`` / ``torch.save`` / ``torch.load``.

    Args:
        has_fp16_weights: passed to both modules and to ``Int8Params``; when False the
            saved 8-bit state must be less than half the size of the fp32 state.
        serialize_before_forward: take ``state_dict()`` before (True) or after (False)
            the first forward/backward pass.
        deserialize_before_cuda: call ``load_state_dict`` on the module before
            ``.cuda()``. With 8-bit weights this is expected to raise RuntimeError,
            and the output comparison is skipped.
        force_no_igemmlt: set ``state.force_no_igemmlt`` on both modules.

    The reloaded module must reproduce the first module's outputs and input gradients
    (``atol=1e-5``) whenever loading is expected to succeed.
    """
    linear = torch.nn.Linear(32, 96)
    x = torch.randn(3, 32, dtype=torch.half)

    linear_custom = _build_linear8bitlt(linear, has_fp16_weights, force_no_igemmlt)
    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
    )
    # Shares the same Parameter object with ``linear``.
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()

    if serialize_before_forward:
        state_dict_8bit = linear_custom.state_dict()

    x_first = x.clone().cuda().requires_grad_(True)
    fx_first = linear_custom(x_first).float()
    grad_proj = torch.randn_like(fx_first)
    (fx_first * grad_proj).mean().backward()

    if not serialize_before_forward:
        state_dict_8bit = linear_custom.state_dict()

    with TemporaryDirectory() as tmpdir:
        state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
        state_path = os.path.join(tmpdir, "state.pth")

        torch.save(linear.state_dict(), state_path)
        torch.save(state_dict_8bit, state_path_8bit)

        if not has_fp16_weights:
            assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)

        # Must load while the temporary directory (and thus the file) still exists.
        new_state_dict = torch.load(state_path_8bit)

    new_linear_custom = _build_linear8bitlt(linear, has_fp16_weights, force_no_igemmlt)

    if deserialize_before_cuda:
        with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
            new_linear_custom.load_state_dict(new_state_dict, strict=True)

    new_linear_custom = new_linear_custom.cuda()

    if not deserialize_before_cuda:
        new_linear_custom.load_state_dict(new_state_dict, strict=True)

    x_second = x.clone().cuda().requires_grad_(True)
    fx_second = new_linear_custom(x_second).float()
    (fx_second * grad_proj).mean().backward()

    # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
    if has_fp16_weights or not deserialize_before_cuda:
        assert torch.allclose(fx_first, fx_second, atol=1e-5)
        assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
|