Fixed bias conversion in Linear4bit
This commit is contained in:
parent
e9fa03b717
commit
b8ea2b416d
|
@ -205,45 +205,13 @@ class Linear4bit(nn.Linear):
|
|||
if self.compute_dtype is not None:
|
||||
x = x.to(self.compute_dtype)
|
||||
|
||||
bias = None if self.bias is None else self.bias.half(self.compute_dtype)
|
||||
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
|
||||
|
||||
out = out.to(inp_dtype)
|
||||
|
||||
return out
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
# we only need to save extra state if .cuda was called
|
||||
# then we have the (1) quantization weight and the (2) quantization config
|
||||
|
||||
#quant_state = getattr(self.weight, 'quant_state', None)
|
||||
#if quant_state is not None:
|
||||
# # 2. quantization state
|
||||
# destination[prefix + 'quant_state'] = quant_state
|
||||
|
||||
#destination[prefix + 'weight'] = self.weight.detach()
|
||||
|
||||
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
#for key in unexpected_keys:
|
||||
# input_name = key[len(prefix):]
|
||||
# if input_name == "quant_state":
|
||||
# if getattr(self.weight, 'quant_state', None) is None:
|
||||
# # buffers not yet initialized, can't call them directly without
|
||||
# raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear4bit is "
|
||||
# "not supported. Please call module.cuda() before module.load_state_dict()")
|
||||
|
||||
# input_param = state_dict[key]
|
||||
# self.weight.quant_state = input_param
|
||||
# assert isinstance(self.weight, Param4bit)
|
||||
# unexpected_keys.remove(key)
|
||||
|
||||
class LinearFP4(Linear4bit):
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
|
||||
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
|
||||
|
|
Loading…
Reference in New Issue
Block a user