Fixed a bug in absmax float conversion.
commit c00402f17e
parent 67475257a9
@@ -268,6 +268,7 @@ Features:
 Bug fixes:
 - Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. #553
 - Fixed a missing scipy dependency in requirements.txt. #544
+- Fixed a bug where a view operation could cause an error in 8-bit layers.
 
 Documentation:
 - Improved documentation for GPUs that do not support 8-bit matmul. #529
@@ -685,10 +685,10 @@ def dequantize_blockwise(
 
     absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
 
-    if absmax.dtype != torch.float32: absmax = absmax.float()
     if nested:
         absmax = dequantize_blockwise(absmax, state2)
         absmax += offset
+        if absmax.dtype != torch.float32: absmax = absmax.float()
 
     if out is None:
         out = torch.empty(A.shape, dtype=dtype, device=A.device)
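For context, a minimal sketch of why the cast has to move: per #553 above, the nested dequantize_blockwise call can hand absmax back in a non-float32 dtype (for example when the process-wide default dtype is bfloat16), so a float32 cast performed before that call is simply overwritten. The names fake_nested_dequant and DEFAULT_DTYPE below are illustrative stand-ins, not the bitsandbytes API.

import torch

# Stand-in for the recursive dequantize_blockwise(absmax, state2) call.
# Assumption from #553: the nested dequantization can return absmax in a
# non-float32 dtype, modeled here as a fixed bf16 "default" dtype.
DEFAULT_DTYPE = torch.bfloat16

def fake_nested_dequant(absmax_q: torch.Tensor) -> torch.Tensor:
    return absmax_q.to(DEFAULT_DTYPE)

absmax = torch.tensor([3, 7], dtype=torch.uint8)  # quantized statistics
offset = 0.25

# Old order: cast first, then dequantize -- the cast is overwritten.
old = fake_nested_dequant(absmax.float()) + offset
assert old.dtype == torch.bfloat16  # downstream float32 kernels then fail

# New order (this commit): dequantize, add the offset, then cast once.
new = fake_nested_dequant(absmax) + offset
if new.dtype != torch.float32:
    new = new.float()
assert new.dtype == torch.float32

Casting last means no later dtype-deriving operation can undo it, which is why the conversion now sits at the end of the nested branch.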