From c1f3f56d2cc18c929dc9b257a24603d26657b0b7 Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Fri, 9 Jun 2023 21:58:39 +0200
Subject: [PATCH] Rearrange the weights directly in state dict before loading

---
 bitsandbytes/nn/modules.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 101c988..b806e94 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -306,6 +306,15 @@ class Int8Params(torch.nn.Parameter):
             return new_param
 
 
+def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    weight = state_dict[f"{prefix}weight"]
+    weight_format = state_dict.pop(f"{prefix}weight_format", "row")
+
+    if weight_format != "row":
+        tile_indices = get_tile_inds(weight_format, weight.device)
+        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+
+
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                  memory_efficient_backward=False, threshold=0.0, index=None):
@@ -321,6 +330,7 @@ class Linear8bitLt(nn.Linear):
             self.state.use_pool = True
 
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         super()._save_to_state_dict(destination, prefix, keep_vars)
@@ -370,12 +380,6 @@ class Linear8bitLt(nn.Linear):
                     self.state.SCB = self.weight.SCB
                 unexpected_keys.remove(key)
 
-            if input_name == "weight_format":
-                input_param = state_dict[key]
-                if input_param != "row":
-                    tile_indices = get_tile_inds(input_param, self.weight.device)
-                    self.weight.data = self.weight.CB = undo_layout(self.weight.data, tile_indices)
-                unexpected_keys.remove(key)
 
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
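
The patch replaces the ad-hoc `weight_format` handling in `_load_from_state_dict` with a
load-state-dict pre-hook: a non-row weight is restored to row-major layout inside the state
dict itself, before `nn.Module` copies it into `self.weight`, so `Linear8bitLt` no longer
has to rewrite `self.weight.data` after the fact. Below is a minimal, self-contained sketch
of that mechanism, not part of the patch: it assumes a recent PyTorch, and `ToyLinear`,
`maybe_undo_toy_layout`, and the `layout_format` key are hypothetical stand-ins (a transpose
plays the role of `undo_layout`). The `weight is None` guard is a defensive addition that
the patch itself does not include.

import torch
import torch.nn as nn


def maybe_undo_toy_layout(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # Defensive guard (not in the patch): with strict=False the checkpoint
    # may not contain this module's weight at all.
    weight = state_dict.get(f"{prefix}weight")
    if weight is None:
        return
    # Popping the auxiliary key keeps load_state_dict from reporting it
    # as an unexpected key.
    layout = state_dict.pop(f"{prefix}layout_format", "row")
    if layout != "row":
        # Stand-in for undo_layout: rewrite the tensor inside the state
        # dict itself, before nn.Module copies it into self.weight.
        state_dict[f"{prefix}weight"] = weight.t().contiguous()


class ToyLinear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
        # Same private PyTorch API the patch uses; hooks registered this
        # way run at the start of _load_from_state_dict.
        self._register_load_state_dict_pre_hook(maybe_undo_toy_layout)


layer = ToyLinear(4, 8)
checkpoint = {
    "weight": torch.randn(8, 4).t().contiguous(),  # stored column-wise
    "layout_format": "col",
}
layer.load_state_dict(checkpoint)  # hook restores shape (8, 4) before loading

Because pre-hooks run before the module's own loading logic, and `load_state_dict` works on
a copy of the caller's dict, popping the auxiliary key there both leaves the caller's
checkpoint untouched and keeps the key from being flagged as unexpected, which is why the
patch can drop the manual `unexpected_keys.remove(key)` bookkeeping for `weight_format`.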