From c00402f17e0483ede3fa841f6c4e0031a4f72a34 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Thu, 13 Jul 2023 21:47:38 -0700
Subject: [PATCH] Fixed a bug in absmax float conversion.

---
 CHANGELOG.md               | 1 +
 bitsandbytes/functional.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 54b7611..32400cc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -268,6 +268,7 @@ Features:
 Bug fixes:
  - Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553
  - Fixed a missing scipy dependency in requirements.txt. #544
+ - Fixed a bug, where a view operation could cause an error in 8-bit layers.
 
 Documentation:
  - Improved documentation for GPUs that do not support 8-bit matmul. #529

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index c2956a7..837f6bf 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -685,10 +685,10 @@ def dequantize_blockwise(
 
     absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
 
-    if absmax.dtype != torch.float32: absmax = absmax.float()
     if nested:
         absmax = dequantize_blockwise(absmax, state2)
         absmax += offset
+    if absmax.dtype != torch.float32: absmax = absmax.float()
 
     if out is None:
         out = torch.empty(A.shape, dtype=dtype, device=A.device)
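A note on what the reordering buys: in the nested case, `dequantize_blockwise(absmax, state2)` can hand back `absmax` in a dtype other than `torch.float32`, so a cast performed before that call is silently undone. Below is a minimal sketch of the fixed ordering; `ensure_float32_after_nested` and the `inner_dequant` stand-in are hypothetical names for illustration, not part of the bitsandbytes API.

```python
import torch

def ensure_float32_after_nested(absmax, nested, offset=0.0, inner_dequant=None):
    # Hypothetical helper mirroring the ordering the patch establishes:
    # cast to float32 *after* the nested dequantization, because the
    # inner dequantize may return absmax in a different dtype.
    if nested:
        absmax = inner_dequant(absmax)  # stand-in for dequantize_blockwise(absmax, state2)
        absmax += offset
    if absmax.dtype != torch.float32:
        absmax = absmax.float()  # downstream kernel expects float32 absmax
    return absmax

# An inner dequantize returning bfloat16 defeats a cast done up front
# (the old ordering); done afterwards, float32 is guaranteed.
absmax = torch.randn(4, dtype=torch.bfloat16)
fixed = ensure_float32_after_nested(
    absmax, nested=True, offset=0.1,
    inner_dequant=lambda t: t * 2,  # returns bfloat16, like a non-float32 inner state
)
assert fixed.dtype == torch.float32
```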