Fixed bias conversion in Linear4bit

This commit is contained in:
Tim Dettmers 2023-04-12 12:28:35 -07:00
parent e9fa03b717
commit b8ea2b416d

View File

@ -205,45 +205,13 @@ class Linear4bit(nn.Linear):
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.half(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
out = out.to(inp_dtype)
return out
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
# we only need to save extra state if .cuda was called
# then we have the (1) quantization weight and the (2) quantization config
#quant_state = getattr(self.weight, 'quant_state', None)
#if quant_state is not None:
# # 2. quantization state
# destination[prefix + 'quant_state'] = quant_state
#destination[prefix + 'weight'] = self.weight.detach()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
#for key in unexpected_keys:
# input_name = key[len(prefix):]
# if input_name == "quant_state":
# if getattr(self.weight, 'quant_state', None) is None:
# # buffers not yet initialized, can't call them directly without
# raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear4bit is "
# "not supported. Please call module.cuda() before module.load_state_dict()")
# input_param = state_dict[key]
# self.weight.quant_state = input_param
# assert isinstance(self.weight, Param4bit)
# unexpected_keys.remove(key)
class LinearFP4(Linear4bit):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')