diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 63b7156..c2298c8 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
     return True
 
 
+def _get_tile_size(format):
+    assert format in (
+        "col_turing",
+        "col_ampere",
+    ), f"please find this assert and manually enter tile size for {format}"
+    return (8, 32) if format == "col_turing" else (32, 32)
+
+
+def get_tile_inds(format, device):
+    transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
+    with torch.no_grad():
+        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
+
 @dataclass
 class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
@@ -267,20 +280,10 @@ class MatmulLtState:
         self.SBt = None
         self.CBt = None
 
-    def get_tile_size(self):
-        assert self.formatB in (
-            "col_turing",
-            "col_ampere",
-        ), f"please find this assert and manually enter tile size for {self.formatB}"
-        return (8, 32) if self.formatB == "col_turing" else (32, 32)
-
     @property
     def tile_indices(self):
        if self._tile_indices is None:
-            device = self.CxB.device
-            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
-            with torch.no_grad():
-                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+            self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
         return self._tile_indices
 
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 3284921..b10d45a 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -10,7 +10,7 @@ from torch import Tensor, device, dtype, nn
 
 import bitsandbytes as bnb
 import bitsandbytes.functional
-from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
+from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims
 
@@ -306,6 +306,17 @@ class Int8Params(torch.nn.Parameter):
         return new_param
 
 
+def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    weight = state_dict.get(f"{prefix}weight")
+    if weight is None:
+        # if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
+        return
+    weight_format = state_dict.pop(f"{prefix}weight_format", "row")
+
+    if weight_format != "row":
+        tile_indices = get_tile_inds(weight_format, weight.device)
+        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
@@ -322,52 +333,55 @@ class Linear8bitLt(nn.Linear):
             self.state.use_pool = True
 
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
+        super()._save_to_state_dict(destination, prefix, keep_vars)
 
-        try:
-            if reorder_layout:
-                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
+        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+        scb_name = "SCB"
 
-            super()._save_to_state_dict(destination, prefix, keep_vars)
+        # case 1: .cuda was called, SCB is in self.weight
+        param_from_weight = getattr(self.weight, scb_name)
+        # case 2: self.init_8bit_state was called, SCB is in self.state
+        param_from_state = getattr(self.state, scb_name)
+        # case 3: SCB is in self.state, weight layout reordered after first forward()
+        layout_reordered = self.state.CxB is not None
 
-            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-            weight_name = "SCB"
+        key_name = prefix + f"{scb_name}"
+        format_name = prefix + "weight_format"
 
-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
+        if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[format_name] = "row"
+            elif param_from_state is not None and not layout_reordered:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
+                destination[format_name] = "row"
+            elif param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = self.state.formatB
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                               error_msgs):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                       error_msgs)
-        for key in unexpected_keys:
+        unexpected_copy = list(unexpected_keys)
+
+        for key in unexpected_copy:
             input_name = key[len(prefix):]
             if input_name == "SCB":
                 if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without
+                    # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                        "not supported. Please call module.cuda() before module.load_state_dict()")
 
                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)
+
+                if self.state.SCB is not None:
+                    self.state.SCB = self.weight.SCB
+
                 unexpected_keys.remove(key)
 
     def init_8bit_state(self):
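
Usage sketch (not part of the diff): a minimal round trip that this change is meant to enable, assuming a CUDA device is available; the layer sizes, threshold value, and tolerance below are arbitrary choices for illustration. Saving after the first forward() now records a "weight_format" entry next to the tiled weight, and the maybe_rearrange_weight pre-hook converts the weight back to row-major layout on load.

import torch
import bitsandbytes as bnb

if torch.cuda.is_available():
    # quantize a small fp16 linear layer; has_fp16_weights=False keeps int8 weights after .cuda()
    fp16_linear = torch.nn.Linear(64, 64, bias=False).half()
    int8_linear = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False, threshold=6.0)
    int8_linear.load_state_dict(fp16_linear.state_dict())
    int8_linear = int8_linear.cuda()

    # the first forward() replaces weight.data with the tiled col_turing/col_ampere layout (state.CxB)
    x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        int8_linear(x)

    # the state dict now carries SCB and a "weight_format" entry alongside the (tiled) weight
    state = int8_linear.state_dict()

    # loading runs maybe_rearrange_weight, which undoes the tiling back to row-major;
    # the target module must already be quantized via .cuda(), as the RuntimeError message says
    restored = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False, threshold=6.0).cuda()
    restored.load_state_dict(state)

    # both modules hold the same int8 weights and scales, so outputs should agree
    with torch.no_grad():
        assert torch.allclose(int8_linear(x), restored(x), atol=1e-2)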