[WIP] Implement proper serialization of Linear8bitLt

Max Ryabinin 2023-02-21 12:04:47 +01:00
parent 0f5c394870
commit 58b09ee1b1
2 changed files with 81 additions and 3 deletions
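In short, the change lets an int8-quantized Linear8bitLt survive a state_dict round trip by saving the SCB quantization statistics next to the int8 weight (whose CB data already lives in weight.data). A minimal usage sketch of that round trip, assuming a CUDA device and mirroring the sizes and flags of the new test below; this sketch is not part of the commit itself:

# Hedged sketch: save and reload an int8 Linear8bitLt via state_dict.
import torch
from bitsandbytes.nn.modules import Linear8bitLt

# Moving the module to the GPU quantizes the fp16 weight and produces SCB.
layer = Linear8bitLt(16, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()

# _save_to_state_dict now exports SCB next to the int8 weight.
state_dict = layer.state_dict()

# A freshly constructed copy restores both the int8 weight and SCB.
restored = Linear8bitLt(16, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
restored.load_state_dict(state_dict, strict=True)

x = torch.randn(3, 16, dtype=torch.half, device="cuda")
assert torch.allclose(layer(x).float(), restored(x).float(), atol=1e-5)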


@@ -224,6 +224,34 @@ class Linear8bitLt(nn.Linear):
        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)

        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
        weight_name = "SCB"

        # case 1: .cuda was called, SCB is in self.weight
        param_from_weight = getattr(self.weight, weight_name)
        # case 2: self.init_8bit_state was called, SCB is in self.state
        param_from_state = getattr(self.state, weight_name)

        key_name = prefix + f"{weight_name}"
        if param_from_weight is not None:
            destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
        elif not self.state.has_fp16_weights and param_from_state is not None:
            destination[key_name] = param_from_state if keep_vars else param_from_state.detach()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                      error_msgs)
        for key in unexpected_keys:
            input_name = key[len(prefix):]
            if input_name == "SCB":
                input_param = state_dict[key]
                self.weight.SCB.copy_(input_param)
                unexpected_keys.remove(key)

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
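To make the two cases handled in _save_to_state_dict concrete, here is a hedged illustration of where SCB lives over the module's lifecycle. It is not part of the commit and assumes a CUDA device plus the constructor arguments used in the tests:

# Hedged sketch: SCB moves from the weight to the state after the first forward pass.
import torch
from bitsandbytes.nn.modules import Linear8bitLt

layer = Linear8bitLt(16, 32, bias=False, has_fp16_weights=False, threshold=6.0).cuda()
assert layer.weight.SCB is not None   # case 1: .cuda() quantized the weight, SCB sits on self.weight

_ = layer(torch.randn(3, 16, dtype=torch.half, device="cuda"))
assert layer.state.SCB is not None    # case 2: forward() ran init_8bit_state, SCB now lives in self.state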


@@ -1,11 +1,14 @@
from copy import deepcopy

import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt

# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@@ -26,6 +29,7 @@ def test_layout_exact_match():
    assert restored_x.is_contiguous()
    assert torch.all(torch.eq(restored_x, x))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt():
    linear = torch.nn.Linear(1024, 3072)
@@ -43,7 +47,7 @@ def test_linear_no_igemmlt():
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()
    linear = linear.half().cuda()

    x_ref = x.clone().cuda().requires_grad_(True)
@@ -59,3 +63,49 @@ def test_linear_no_igemmlt():
    assert not linear_custom.state.has_fp16_weights
    assert linear_custom.state.CB is not None
    assert linear_custom.state.CxB is None


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("has_fp16_weights", [False, True])
def test_linear_serialization(has_fp16_weights):
    linear = torch.nn.Linear(16, 32)
    x = torch.randn(3, 16, dtype=torch.half)

    linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=has_fp16_weights,
        threshold=6.0,
    )
    linear_custom.state.force_no_igemmlt = True

    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()

    x_first = x.clone().cuda().requires_grad_(True)
    fx_first = linear_custom(x_first).float()
    grad_proj = torch.randn_like(fx_first)
    (fx_first * grad_proj).mean().backward()

    state_dict = deepcopy(linear_custom.state_dict())

    new_linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=has_fp16_weights,
        threshold=6.0,
    )
    new_linear_custom.state.force_no_igemmlt = True
    new_linear_custom = new_linear_custom.cuda()
    new_linear_custom.load_state_dict(state_dict, strict=True)

    x_second = x.clone().cuda().requires_grad_(True)
    fx_second = new_linear_custom(x_second).float()
    (fx_second * grad_proj).mean().backward()

    assert torch.allclose(fx_first, fx_second, atol=1e-5)
    assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)