diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d0b5d8..b503b35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -264,3 +264,6 @@ Deprecated: Features: - Added precompiled CUDA 11.8 binaries to support H100 GPUs without compilation #571 + +Bug fixes: + - Fixed a bug where the dtype of absmax was not set explicitly, which led to errors if the default dtype is different from torch.float32. #553 diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 033ae32..9ad8daa 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -604,7 +604,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou n = A.numel() blocks = n // blocksize blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device) + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) if out is None: out = torch.zeros_like(A, dtype=torch.uint8) @@ -684,6 +684,8 @@ def dequantize_blockwise( quant_state = (absmax, code, blocksize, False, torch.float32, None, None) absmax, code, blocksize, nested, dtype, offset, state2 = quant_state + + if absmax.dtype != torch.float32: absmax = absmax.float() if nested: absmax = dequantize_blockwise(absmax, state2) absmax += offset @@ -796,7 +798,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz if absmax is None: blocks = n // blocksize blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device) + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) if out is None: @@ -886,6 +888,7 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: if compressed_stats is not None: + if absmax.dtype != torch.float32: absmax = absmax.float() offset, state2 = compressed_stats absmax = dequantize_blockwise(absmax, state2) absmax += offset @@ -930,6 +933,7 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: code = code.to(A.device) absmax = 
torch.abs(A).max() + if absmax.dtype != torch.float32: absmax = absmax.float() inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code)