Revert the layout if weights were reordered

Max Ryabinin 2023-02-25 06:02:06 +01:00
parent cd4d904a4c
commit cc608c04c2

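For context on the hunks below: when has_fp16_weights=False and a forward pass has run on GPU, the quantized weight only exists in self.state.CxB in a hardware-specific tiled layout (the ampere/turing formats mentioned in the comment), while the row-major CB copy is gone; that is exactly the condition the new _save_to_state_dict checks for. Saving the tiled tensor directly would tie checkpoints to one GPU layout, so the weight is converted back to row-major order for serialization and, with this commit, the tiled layout is put back afterwards. A rough sketch of the round-trip this enables follows; the shapes, threshold, and file name are illustrative assumptions, not code from the PR:

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0)
layer = layer.cuda()  # Int8Params quantizes the fp16 weight to int8 on the move to GPU
layer(torch.randn(4, 64, dtype=torch.float16, device="cuda"))  # after this, state.CB is None and state.CxB is set

state = layer.state_dict()  # the saved weight comes out in row-major order
torch.save(state, "linear8bit.pt")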

@@ -9,6 +9,8 @@ import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
from bitsandbytes.optim import GlobalOptimManager
T = TypeVar("T", bound="torch.nn.Module")
@@ -225,6 +227,30 @@ class Linear8bitLt(nn.Linear):
        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
            # reorder weight layout back from ampere/turing to row
            reorder_layout = True
            weight_clone = self.weight.data.clone()
        else:
            reorder_layout = False

        try:
            if reorder_layout:
                if self.state.tile_indices is None:
                    order, tile_size = self.state.formatB, self.state.get_tile_size()
                    transform = lambda x: \
                        bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row",
                                                          to_order=order)[0].to(x.device)
                    with torch.no_grad():
                        self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(
                            self.state.CxB.device)

                CB = (
                    undo_layout(self.state.CxB, self.state.tile_indices)
                )

                self.weight.data = CB

            super()._save_to_state_dict(destination, prefix, keep_vars)

            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
@@ -240,6 +266,9 @@ class Linear8bitLt(nn.Linear):
                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
            elif not self.state.has_fp16_weights and param_from_state is not None:
                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
        finally:
            if reorder_layout:
                self.weight.data = weight_clone
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
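The revert relies on the two helpers imported at the top of the diff: get_inverse_transform_indices traces the opaque GPU transform once, by pushing index values through it, to learn where every element of a tile ends up, and undo_layout uses those indices to scatter a reordered tensor back into row-major order. Below is a toy, self-contained sketch of that idea with hypothetical helper names; the real get_inverse_transform_indices additionally encodes indices into int8 bytes (the CUDA transform only accepts int8 tiles), and undo_layout handles tensors composed of many tiles:

import torch

def trace_tile_layout(transform, tile_shape):
    # Push the positions 0..N-1 through the layout transform and read off,
    # for each slot of the transformed buffer, which row-major position landed there.
    rows, cols = tile_shape
    positions = torch.arange(rows * cols).reshape(rows, cols)
    return transform(positions).flatten()

def undo_tile_layout(reordered, traced_indices, tile_shape):
    # Scatter the reordered buffer back to row-major order using the traced indices.
    restored = torch.empty_like(reordered.flatten())
    restored[traced_indices] = reordered.flatten()
    return restored.reshape(tile_shape)

# Toy stand-in for the ampere/turing reordering: transpose the tile.
transform = lambda tile: tile.T.contiguous()

tile = torch.arange(12).reshape(3, 4)
traced = trace_tile_layout(transform, tile.shape)
assert torch.equal(undo_tile_layout(transform(tile), traced, tile.shape), tile)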