Fixed a bug in absmax float conversion.

Tim Dettmers 2023-07-13 21:47:38 -07:00
parent 67475257a9
commit c00402f17e
2 changed files with 2 additions and 1 deletion


@@ -268,6 +268,7 @@ Features:
 Bug fixes:
+- Fixed a bug where the default type of absmax was undefined, which led to errors if the default type was different from torch.float32. #553
 - Fixed a missing scipy dependency in requirements.txt. #544
 - Fixed a bug where a view operation could cause an error in 8-bit layers.
 Documentation:
 - Improved documentation for GPUs that do not support 8-bit matmul. #529


@@ -685,10 +685,10 @@ def dequantize_blockwise(
         absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
-        if absmax.dtype != torch.float32: absmax = absmax.float()
     if nested:
         absmax = dequantize_blockwise(absmax, state2)
         absmax += offset
+        if absmax.dtype != torch.float32: absmax = absmax.float()
     if out is None:
         out = torch.empty(A.shape, dtype=dtype, device=A.device)
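In essence this is an ordering fix: absmax is expected to be float32 downstream, but when the quantization state is nested, dequantize_blockwise(absmax, state2) materializes a fresh tensor in the process-wide default dtype, so a conversion done before that call (the removed line) is silently undone whenever the default dtype is not float32 (#553). Below is a minimal sketch of the effect, under the stated assumption about the inner call's output dtype; fake_nested_dequantize is a hypothetical stand-in, not bitsandbytes code.

import torch

# Hypothetical stand-in for dequantize_blockwise(absmax, state2):
# assume only that the inner dequantization materializes its output
# in torch's process-wide default dtype.
def fake_nested_dequantize(absmax_q: torch.Tensor) -> torch.Tensor:
    return absmax_q.to(torch.get_default_dtype())

torch.set_default_dtype(torch.bfloat16)  # e.g. set during low-precision model init

absmax = torch.tensor([0.5, 1.0], dtype=torch.bfloat16)
offset = 0.1

# Removed ordering: convert first, then nested-dequantize.
a = absmax.float()                        # float32 here...
a = fake_nested_dequantize(a) + offset    # ...but bfloat16 again here
print(a.dtype)                            # torch.bfloat16

# Fixed ordering: nested-dequantize first, then convert.
b = fake_nested_dequantize(absmax) + offset
if b.dtype != torch.float32:
    b = b.float()
print(b.dtype)                            # torch.float32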