Fixed a bug in absmax float conversion.

Tim Dettmers 2023-07-13 21:47:38 -07:00
parent 67475257a9
commit c00402f17e
2 changed files with 2 additions and 1 deletion


@@ -268,6 +268,7 @@ Features:
 Bug fixes:
 - Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553
 - Fixed a missing scipy dependency in requirements.txt. #544
+- Fixed a bug, where a view operation could cause an error in 8-bit layers.
 Documentation:
 - Improved documentation for GPUs that do not support 8-bit matmul. #529


@@ -685,10 +685,10 @@ def dequantize_blockwise(
     absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
-    if absmax.dtype != torch.float32: absmax = absmax.float()
     if nested:
         absmax = dequantize_blockwise(absmax, state2)
         absmax += offset
+        if absmax.dtype != torch.float32: absmax = absmax.float()
     if out is None:
         out = torch.empty(A.shape, dtype=dtype, device=A.device)
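
For context, a minimal sketch of the ordering issue this commit fixes. This is not bitsandbytes code: dequant and CODEBOOK are hypothetical stand-ins for the recursive dequantize_blockwise(absmax, state2) call and its codebook. With nested quantization, absmax itself holds uint8 codes, so casting it to float32 before the recursive dequantization would feed float values where integer codes are expected; the fix defers the cast until after dequantization and the offset add.

import torch

# Hypothetical codebook mapping uint8 codes back to float values;
# stands in for the nested quant state (state2) in the real code.
CODEBOOK = torch.linspace(-1.0, 1.0, 256)

def dequant(codes: torch.Tensor) -> torch.Tensor:
    # Stand-in for the recursive dequantize_blockwise(absmax, state2)
    # call: look up each uint8 code, which requires integer indices.
    return CODEBOOK[codes.long()]

absmax = torch.tensor([0, 128, 255], dtype=torch.uint8)  # nested-quantized codes
offset = 0.5

# Buggy order (parent commit): the float32 cast ran before the nested
# branch, so absmax could already be a float tensor when the recursive
# dequantization expected the raw uint8 codes.
#
# Fixed order (this commit): dequantize the nested codes first, add the
# offset, and only then normalize the dtype to float32.
absmax = dequant(absmax)
absmax += offset
if absmax.dtype != torch.float32:
    absmax = absmax.float()
print(absmax)  # tensor([-0.5000, 0.5039, 1.5000])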