Merge pull request #159 from TimDettmers/serialize_8bit
Implement proper serialization of Linear8bitLt
This commit is contained in:
commit
ed6f3eb146
|
@ -234,7 +234,7 @@ def supports_igemmlt(device: torch.device) -> bool:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MatmulLtState:
|
class MatmulLtState:
|
||||||
tile_indices: Optional[torch.Tensor] = None
|
_tile_indices: Optional[torch.Tensor] = None
|
||||||
force_no_igemmlt: bool = False
|
force_no_igemmlt: bool = False
|
||||||
CB = None
|
CB = None
|
||||||
CxB = None
|
CxB = None
|
||||||
|
@ -274,6 +274,15 @@ class MatmulLtState:
|
||||||
), f"please find this assert and manually enter tile size for {self.formatB}"
|
), f"please find this assert and manually enter tile size for {self.formatB}"
|
||||||
return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tile_indices(self):
|
||||||
|
if self._tile_indices is None:
|
||||||
|
device = self.CxB.device
|
||||||
|
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
|
||||||
|
return self._tile_indices
|
||||||
|
|
||||||
|
|
||||||
class MatMul8bitLt(torch.autograd.Function):
|
class MatMul8bitLt(torch.autograd.Function):
|
||||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||||
|
@ -466,13 +475,6 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
|
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
|
||||||
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||||
elif state.CxB is not None:
|
elif state.CxB is not None:
|
||||||
|
|
||||||
if state.tile_indices is None:
|
|
||||||
order, tile_size = state.formatB, state.get_tile_size()
|
|
||||||
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
|
|
||||||
|
|
||||||
CB = (
|
CB = (
|
||||||
undo_layout(state.CxB, state.tile_indices)
|
undo_layout(state.CxB, state.tile_indices)
|
||||||
.to(ctx.dtype_A)
|
.to(ctx.dtype_A)
|
||||||
|
|
|
@ -9,6 +9,8 @@ import torch.nn.functional as F
|
||||||
from torch import Tensor, device, dtype, nn
|
from torch import Tensor, device, dtype, nn
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
import bitsandbytes.functional
|
||||||
|
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
|
||||||
from bitsandbytes.optim import GlobalOptimManager
|
from bitsandbytes.optim import GlobalOptimManager
|
||||||
|
|
||||||
T = TypeVar("T", bound="torch.nn.Module")
|
T = TypeVar("T", bound="torch.nn.Module")
|
||||||
|
@ -224,6 +226,53 @@ class Linear8bitLt(nn.Linear):
|
||||||
|
|
||||||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
||||||
|
|
||||||
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
|
||||||
|
# reorder weight layout back from ampere/turing to row
|
||||||
|
reorder_layout = True
|
||||||
|
weight_clone = self.weight.data.clone()
|
||||||
|
else:
|
||||||
|
reorder_layout = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if reorder_layout:
|
||||||
|
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
|
||||||
|
|
||||||
|
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||||
|
|
||||||
|
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
|
||||||
|
weight_name = "SCB"
|
||||||
|
|
||||||
|
# case 1: .cuda was called, SCB is in self.weight
|
||||||
|
param_from_weight = getattr(self.weight, weight_name)
|
||||||
|
# case 2: self.init_8bit_state was called, SCB is in self.state
|
||||||
|
param_from_state = getattr(self.state, weight_name)
|
||||||
|
|
||||||
|
key_name = prefix + f"{weight_name}"
|
||||||
|
if param_from_weight is not None:
|
||||||
|
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
|
||||||
|
elif not self.state.has_fp16_weights and param_from_state is not None:
|
||||||
|
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||||
|
finally:
|
||||||
|
if reorder_layout:
|
||||||
|
self.weight.data = weight_clone
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||||
|
missing_keys, unexpected_keys, error_msgs):
|
||||||
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||||
|
error_msgs)
|
||||||
|
for key in unexpected_keys:
|
||||||
|
input_name = key[len(prefix):]
|
||||||
|
if input_name == "SCB":
|
||||||
|
if self.weight.SCB is None:
|
||||||
|
# buffers not yet initialized, can't call them directly without
|
||||||
|
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||||
|
"not supported. Please call module.cuda() before module.load_state_dict()")
|
||||||
|
|
||||||
|
input_param = state_dict[key]
|
||||||
|
self.weight.SCB.copy_(input_param)
|
||||||
|
unexpected_keys.remove(key)
|
||||||
|
|
||||||
def init_8bit_state(self):
|
def init_8bit_state(self):
|
||||||
self.state.CB = self.weight.CB
|
self.state.CB = self.weight.CB
|
||||||
self.state.SCB = self.weight.SCB
|
self.state.SCB = self.weight.SCB
|
||||||
|
|
|
@ -1,11 +1,17 @@
|
||||||
import bitsandbytes as bnb
|
import os
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from itertools import product
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from bitsandbytes import functional as F
|
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
from bitsandbytes import functional as F
|
||||||
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
|
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
|
||||||
from bitsandbytes.nn.modules import Linear8bitLt
|
from bitsandbytes.nn.modules import Linear8bitLt
|
||||||
|
|
||||||
|
|
||||||
# contributed by Alex Borzunov, see:
|
# contributed by Alex Borzunov, see:
|
||||||
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
|
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
|
||||||
|
|
||||||
|
@ -26,6 +32,7 @@ def test_layout_exact_match():
|
||||||
assert restored_x.is_contiguous()
|
assert restored_x.is_contiguous()
|
||||||
assert torch.all(torch.eq(restored_x, x))
|
assert torch.all(torch.eq(restored_x, x))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||||
def test_linear_no_igemmlt():
|
def test_linear_no_igemmlt():
|
||||||
linear = torch.nn.Linear(1024, 3072)
|
linear = torch.nn.Linear(1024, 3072)
|
||||||
|
@ -43,7 +50,7 @@ def test_linear_no_igemmlt():
|
||||||
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
|
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
|
||||||
).to(linear.weight.dtype)
|
).to(linear.weight.dtype)
|
||||||
linear_custom.bias = linear.bias
|
linear_custom.bias = linear.bias
|
||||||
linear = linear_custom.cuda()
|
linear_custom = linear_custom.cuda()
|
||||||
linear = linear.half().cuda()
|
linear = linear.half().cuda()
|
||||||
|
|
||||||
x_ref = x.clone().cuda().requires_grad_(True)
|
x_ref = x.clone().cuda().requires_grad_(True)
|
||||||
|
@ -59,3 +66,78 @@ def test_linear_no_igemmlt():
|
||||||
assert not linear_custom.state.has_fp16_weights
|
assert not linear_custom.state.has_fp16_weights
|
||||||
assert linear_custom.state.CB is not None
|
assert linear_custom.state.CB is not None
|
||||||
assert linear_custom.state.CxB is None
|
assert linear_custom.state.CxB is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||||
|
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
|
||||||
|
list(product([False, True], [False, True], [False, True], [False, True])))
|
||||||
|
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
|
||||||
|
linear = torch.nn.Linear(32, 96)
|
||||||
|
x = torch.randn(3, 32, dtype=torch.half)
|
||||||
|
|
||||||
|
linear_custom = Linear8bitLt(
|
||||||
|
linear.in_features,
|
||||||
|
linear.out_features,
|
||||||
|
linear.bias is not None,
|
||||||
|
has_fp16_weights=has_fp16_weights,
|
||||||
|
threshold=6.0,
|
||||||
|
)
|
||||||
|
if force_no_igemmlt:
|
||||||
|
linear_custom.state.force_no_igemmlt = True
|
||||||
|
|
||||||
|
linear_custom.weight = bnb.nn.Int8Params(
|
||||||
|
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
|
||||||
|
)
|
||||||
|
linear_custom.bias = linear.bias
|
||||||
|
linear_custom = linear_custom.cuda()
|
||||||
|
|
||||||
|
if serialize_before_forward:
|
||||||
|
state_dict_8bit = linear_custom.state_dict()
|
||||||
|
|
||||||
|
x_first = x.clone().cuda().requires_grad_(True)
|
||||||
|
fx_first = linear_custom(x_first).float()
|
||||||
|
grad_proj = torch.randn_like(fx_first)
|
||||||
|
(fx_first * grad_proj).mean().backward()
|
||||||
|
|
||||||
|
if not serialize_before_forward:
|
||||||
|
state_dict_8bit = linear_custom.state_dict()
|
||||||
|
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
|
||||||
|
state_path = os.path.join(tmpdir, "state.pth")
|
||||||
|
|
||||||
|
torch.save(linear.state_dict(), state_path)
|
||||||
|
torch.save(state_dict_8bit, state_path_8bit)
|
||||||
|
|
||||||
|
if not has_fp16_weights:
|
||||||
|
assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
|
||||||
|
|
||||||
|
new_state_dict = torch.load(state_path_8bit)
|
||||||
|
|
||||||
|
new_linear_custom = Linear8bitLt(
|
||||||
|
linear.in_features,
|
||||||
|
linear.out_features,
|
||||||
|
linear.bias is not None,
|
||||||
|
has_fp16_weights=has_fp16_weights,
|
||||||
|
threshold=6.0,
|
||||||
|
)
|
||||||
|
if force_no_igemmlt:
|
||||||
|
new_linear_custom.state.force_no_igemmlt = True
|
||||||
|
|
||||||
|
if deserialize_before_cuda:
|
||||||
|
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
|
||||||
|
new_linear_custom.load_state_dict(new_state_dict, strict=True)
|
||||||
|
|
||||||
|
new_linear_custom = new_linear_custom.cuda()
|
||||||
|
|
||||||
|
if not deserialize_before_cuda:
|
||||||
|
new_linear_custom.load_state_dict(new_state_dict, strict=True)
|
||||||
|
|
||||||
|
x_second = x.clone().cuda().requires_grad_(True)
|
||||||
|
fx_second = new_linear_custom(x_second).float()
|
||||||
|
(fx_second * grad_proj).mean().backward()
|
||||||
|
|
||||||
|
# if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
|
||||||
|
if has_fp16_weights or not deserialize_before_cuda:
|
||||||
|
assert torch.allclose(fx_first, fx_second, atol=1e-5)
|
||||||
|
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user