From b8ea2b416d25130ed32a3cf436b8a9f8fd1d412f Mon Sep 17 00:00:00 2001
From: Tim Dettmers <tim.dettmers@gmail.com>
Date: Wed, 12 Apr 2023 12:28:35 -0700
Subject: [PATCH] Fixed bias conversion in Linear4bit

---
 bitsandbytes/nn/modules.py | 34 +---------------------------------
 1 file changed, 1 insertion(+), 33 deletions(-)
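
Note on the fix: torch.Tensor.half() does not take a dtype argument, so the
old cast self.bias.half(self.compute_dtype) raises a TypeError whenever a
bias is present, and .half() would in any case always produce float16 even
when compute_dtype is bfloat16 or float32. Casting with
self.bias.to(self.compute_dtype) mirrors how the activations are cast a few
lines above. A minimal sketch of the difference, assuming only that torch is
importable (the names below are illustrative, not taken from the module):

    import torch

    bias = torch.zeros(8)                 # stand-in for Linear4bit.bias
    compute_dtype = torch.bfloat16        # as set via the compute_dtype arg

    # old behaviour: .half() accepts no dtype, so this call fails outright
    # bias.half(compute_dtype)            # TypeError

    # new behaviour: cast the bias to whatever compute dtype was configured
    bias = bias.to(compute_dtype)         # tensor of dtype torch.bfloat16

The hunk below also drops the commented-out _save_to_state_dict /
_load_from_state_dict stubs, which were never active.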

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index de9e4ac..ab16e01 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -205,45 +205,13 @@ class Linear4bit(nn.Linear):
         if self.compute_dtype is not None:
             x = x.to(self.compute_dtype)
 
-        bias = None if self.bias is None else self.bias.half(self.compute_dtype)
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
         out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
 
         out = out.to(inp_dtype)
 
         return out
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        super()._save_to_state_dict(destination, prefix, keep_vars)
-
-        # we only need to save extra state if .cuda was called
-        # then we have the (1) quantization weight and the (2) quantization config
-
-        #quant_state = getattr(self.weight, 'quant_state', None)
-        #if quant_state is not None:
-        #    # 2. quantization state
-        #    destination[prefix + 'quant_state'] = quant_state
-
-        #destination[prefix + 'weight'] = self.weight.detach()
-
-
-
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                                      error_msgs)
-        #for key in unexpected_keys:
-        #    input_name = key[len(prefix):]
-        #    if input_name == "quant_state":
-        #        if getattr(self.weight, 'quant_state', None) is None:
-        #            # buffers not yet initialized, can't call them directly without
-        #            raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear4bit is "
-        #                               "not supported. Please call module.cuda() before module.load_state_dict()")
-
-        #        input_param = state_dict[key]
-        #        self.weight.quant_state = input_param
-        #        assert isinstance(self.weight, Param4bit)
-        #        unexpected_keys.remove(key)
-
 class LinearFP4(Linear4bit):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')