Fixed bias conversion in Linear4bit
parent e9fa03b717
commit b8ea2b416d

@@ -205,45 +205,13 @@ class Linear4bit(nn.Linear):
         if self.compute_dtype is not None:
             x = x.to(self.compute_dtype)

-        bias = None if self.bias is None else self.bias.half(self.compute_dtype)
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
         out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)

         out = out.to(inp_dtype)

         return out

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        super()._save_to_state_dict(destination, prefix, keep_vars)
-
-        # we only need to save extra state if .cuda was called
-        # then we have the (1) quantization weight and the (2) quantization config
-
-        #quant_state = getattr(self.weight, 'quant_state', None)
-        #if quant_state is not None:
-        #    # 2. quantization state
-        #    destination[prefix + 'quant_state'] = quant_state
-
-        #destination[prefix + 'weight'] = self.weight.detach()
-
-
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                                      error_msgs)
-        #for key in unexpected_keys:
-        #    input_name = key[len(prefix):]
-        #    if input_name == "quant_state":
-        #        if getattr(self.weight, 'quant_state', None) is None:
-        #            # buffers not yet initialized, can't call them directly without
-        #            raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear4bit is "
-        #                               "not supported. Please call module.cuda() before module.load_state_dict()")
-
-        #        input_param = state_dict[key]
-        #        self.weight.quant_state = input_param
-        #        assert isinstance(self.weight, Param4bit)
-        #        unexpected_keys.remove(key)
-
 class LinearFP4(Linear4bit):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
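Note on the fix: torch.Tensor.half() does not take a dtype argument (it always converts to float16), so the old call self.bias.half(self.compute_dtype) could never honor a non-float16 compute_dtype; Tensor.to(dtype) performs the intended cast. A minimal standalone sketch of the difference (not library code):

    import torch

    bias = torch.zeros(4, dtype=torch.float32)
    compute_dtype = torch.bfloat16

    # Old call: half() accepts no dtype, so passing compute_dtype raises a
    # TypeError -- and even a bare bias.half() would ignore compute_dtype
    # and always produce float16.
    # bias.half(compute_dtype)

    # Fixed call: to() casts to exactly the requested compute dtype.
    bias = bias.to(compute_dtype)
    assert bias.dtype == torch.bfloat16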
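For context, a usage sketch of the affected path (shapes and dtypes are illustrative; assumes a CUDA device and a bitsandbytes build containing this fix):

    import torch
    import bitsandbytes as bnb

    # bias=True plus a non-None compute_dtype exercises the fixed line:
    # the float32 bias is cast to compute_dtype before bnb.matmul_4bit.
    layer = bnb.nn.LinearFP4(128, 256, bias=True, compute_dtype=torch.bfloat16).cuda()
    x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
    out = layer(x)
    assert out.dtype == torch.bfloat16  # forward casts the result back to the input dtype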