From b8ea2b416d25130ed32a3cf436b8a9f8fd1d412f Mon Sep 17 00:00:00 2001 From: Tim Dettmers <tim.dettmers@gmail.com> Date: Wed, 12 Apr 2023 12:28:35 -0700 Subject: [PATCH] Fixed bias conversion in Linear4bit --- bitsandbytes/nn/modules.py | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index de9e4ac..ab16e01 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -205,45 +205,13 @@ class Linear4bit(nn.Linear): if self.compute_dtype is not None: x = x.to(self.compute_dtype) - bias = None if self.bias is None else self.bias.half(self.compute_dtype) + bias = None if self.bias is None else self.bias.to(self.compute_dtype) out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) out = out.to(inp_dtype) return out - def _save_to_state_dict(self, destination, prefix, keep_vars): - super()._save_to_state_dict(destination, prefix, keep_vars) - - # we only need to save extra state if .cuda was called - # then we have the (1) quantization weight and the (2) quantization config - - #quant_state = getattr(self.weight, 'quant_state', None) - #if quant_state is not None: - # # 2. quantization state - # destination[prefix + 'quant_state'] = quant_state - - #destination[prefix + 'weight'] = self.weight.detach() - - - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) - #for key in unexpected_keys: - # input_name = key[len(prefix):] - # if input_name == "quant_state": - # if getattr(self.weight, 'quant_state', None) is None: - # # buffers not yet initialized, can't call them directly without - # raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear4bit is " - # "not supported. Please call module.cuda() before module.load_state_dict()") - - # input_param = state_dict[key] - # self.weight.quant_state = input_param - # assert isinstance(self.weight, Param4bit) - # unexpected_keys.remove(key) - class LinearFP4(Linear4bit): def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')