commit 67475257a9
parent 8a20cd864b
@@ -267,3 +267,8 @@ Features:
 Bug fixes:
+- Fixed a bug where the default type of absmax was undefined, which led to errors if the default type is different from torch.float32. #553
 - Fixed a missing scipy dependency in requirements.txt. #544
 
+Documentation:
+- Improved documentation for GPUs that do not support 8-bit matmul. #529
+- Added description and pointers for the NF4 data type. #543
+
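The absmax entry concerns torch's process-wide default dtype: tensors created without an explicit dtype inherit it, so quantization statistics silently stop being float32 once a caller changes the default. A minimal, self-contained sketch of the failure mode the fix guards against (nothing here is bitsandbytes-specific):

import torch

# Once the default dtype changes, tensors created without an explicit
# dtype are no longer float32 -- the situation #553 describes.
torch.set_default_dtype(torch.float16)
absmax = torch.tensor([0.5, 1.25, 3.0])   # inherits the float16 default
assert absmax.dtype == torch.float16

# The defensive cast from this commit: normalize to float32 before use.
if absmax.dtype != torch.float32:
    absmax = absmax.float()
assert absmax.dtype == torch.float32

torch.set_default_dtype(torch.float32)    # restore the usual default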
@@ -163,7 +163,8 @@ def is_cublasLt_compatible(cc):
     if cc is not None:
         cc_major, cc_minor = cc.split('.')
         if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
-            CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
+            CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! \
+If you run into issues with 8-bit matmul, you can try 4-bit quantization: https://huggingface.co/blog/4bit-transformers-bitsandbytes", is_warning=True)
         else:
             has_cublaslt = True
     return has_cublaslt
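For context, the compute capability fed into this check can be read directly from torch. A standalone sketch of the same gate (the helper name here is illustrative, not part of bitsandbytes):

import torch

def supports_fast_int8_matmul(device_index: int = 0) -> bool:
    # cuBLASLt int8 kernels require compute capability >= 7.5 (Turing+).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device_index)
    return major > 7 or (major == 7 and minor >= 5)

if not supports_fast_int8_matmul():
    print("Only slow 8-bit matmul available; consider 4-bit quantization.")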
@@ -718,6 +718,16 @@ def get_4bit_type(typename, device=None, blocksize=64):
     if device is None: device = 'cuda'
     data = None
     if typename == 'nf4':
+        ''' Implements the NF4 data type.
+
+            Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
+            is normalized into the range [-1, 1].
+
+            For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)
+
+            Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
+            the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
+        '''
         data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
                 -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
                 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
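The docstring's construction can be made concrete in a few lines: take the midpoints of equal-probability bins under N(0, 1), map them through the normal quantile function, and normalize into [-1, 1]. This is a simplified illustration requiring scipy, not the exact `create_normal_map` logic (which treats the positive and negative halves asymmetrically around a shared zero):

import torch
from scipy.stats import norm

def equal_area_levels(n_levels: int = 16) -> torch.Tensor:
    # Midpoints of n equal-probability bins under N(0, 1); midpoints
    # rather than edges keep the infinite tails out of the map.
    probs = (torch.arange(n_levels, dtype=torch.float64) + 0.5) / n_levels
    levels = torch.from_numpy(norm.ppf(probs.numpy()))
    # Normalize so the extreme levels land exactly on -1 and +1.
    return levels / levels.abs().max()

print(equal_area_levels())  # 16 symmetric values, denser near zero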
@@ -731,6 +741,7 @@ def get_4bit_type(typename, device=None, blocksize=64):
         # 0b101 = 6
         # 0b110 = 2
         # 0b111 = 3
+        # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
         data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
     elif typename == 'int4':
         data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
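A quick usage sketch of `get_4bit_type` as these hunks leave it, assuming it returns the 16 code values as a tensor on the requested device:

import bitsandbytes.functional as F

nf4 = F.get_4bit_type('nf4', device='cuda')    # the 16 NF4 code values
int4 = F.get_4bit_type('int4', device='cuda')  # the 16 int4 code values
print(nf4.shape, int4.shape)                   # torch.Size([16]) each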
@@ -888,10 +899,10 @@ def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax:
 
 
     if compressed_stats is not None:
-        if absmax.dtype != torch.float32: absmax = absmax.float()
         offset, state2 = compressed_stats
         absmax = dequantize_blockwise(absmax, state2)
         absmax += offset
+        if absmax.dtype != torch.float32: absmax = absmax.float()
 
     if out is None:
         out = torch.empty(shape, dtype=dtype, device=A.device)
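The placement matters: absmax is re-materialized by `dequantize_blockwise`, so a cast performed before that call (the removed line) cannot pin the final dtype, while the cast after the offset addition (the added line) guarantees float32 regardless of the ambient default dtype. A hedged round-trip sketch of the path this code runs through, assuming a CUDA device and this version's functional API:

import torch
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')

# Quantize to 4-bit NF4 with nested ("compressed") statistics, then
# dequantize -- the code path the absmax-dtype fix above runs through.
qW, state = F.quantize_4bit(W, quant_type='nf4', compress_statistics=True)
W_hat = F.dequantize_4bit(qW, state)

print((W - W_hat).abs().mean())  # small reconstruction error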
@@ -229,6 +229,16 @@ class LinearFP4(Linear4bit):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
 
 class LinearNF4(Linear4bit):
+    ''' Implements the NF4 data type.
+
+        Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that
+        is normalized into the range [-1, 1].
+
+        For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314)
+
+        Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
+        the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
+    '''
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
 
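A short usage sketch for the class this hunk documents, assuming a CUDA device (weights are quantized to NF4 when the module is moved to the GPU):

import torch
import bitsandbytes as bnb

# Drop-in replacement for nn.Linear with 4-bit NF4 weight storage.
layer = bnb.nn.LinearNF4(1024, 1024, compute_dtype=torch.float16).cuda()

x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
y = layer(x)
print(y.shape)  # torch.Size([8, 1024])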