Fixed a bug in absmax float conversion.
This commit is contained in: parent 67475257a9, commit c00402f17e
@@ -268,6 +268,7 @@ Features:
 Bug fixes:
+- Fixed a bug where the default type of absmax was undefined, which leads to errors if the default type is different from torch.float32. #553
 - Fixed a missing scipy dependency in requirements.txt. #544
 - Fixed a bug where a view operation could cause an error in 8-bit layers.

 Documentation:
 - Improved documentation for GPUs that do not support 8-bit matmul. #529
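For context on the absmax entry above: tensors created without an explicit dtype inherit torch's global default, so a process that changes the default away from torch.float32 ends up with an absmax buffer the quantization kernels cannot consume. A minimal sketch of the failure mode in plain PyTorch (the buffer name and size are illustrative, not taken from the library; assumes a PyTorch version that accepts torch.bfloat16 as a default dtype):

    import torch

    # Simulate a process that changed the global default dtype.
    torch.set_default_dtype(torch.bfloat16)

    # With no explicit dtype, a fresh buffer inherits that default:
    absmax = torch.zeros(256)
    print(absmax.dtype)   # torch.bfloat16, not the torch.float32 the kernels expect

    # The one-line fix in the commit normalizes the buffer before it
    # reaches the C kernel:
    if absmax.dtype != torch.float32:
        absmax = absmax.float()
    print(absmax.dtype)   # torch.float32

    torch.set_default_dtype(torch.float32)  # restore the usual default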
@@ -685,10 +685,10 @@ def dequantize_blockwise(
     absmax, code, blocksize, nested, dtype, offset, state2 = quant_state

+    if absmax.dtype != torch.float32: absmax = absmax.float()
     if nested:
         absmax = dequantize_blockwise(absmax, state2)
         absmax += offset
-        if absmax.dtype != torch.float32: absmax = absmax.float()

     if out is None:
         out = torch.empty(A.shape, dtype=dtype, device=A.device)
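The hunk hoists the float32 cast out of the nested branch so it runs on every call: previously only nested quantization states were normalized, and the common non-nested path could hand a bfloat16 or float16 absmax straight to the kernel. A rough end-to-end reproduction of the scenario this fixes, assuming a CUDA build of bitsandbytes and the quantize_blockwise/dequantize_blockwise pair as they existed around this release (treat the shapes and the exact error magnitude as illustrative):

    import torch
    import bitsandbytes.functional as F

    # A non-float32 default dtype is what made quantize_blockwise
    # allocate a non-float32 absmax in the first place (issue #553).
    torch.set_default_dtype(torch.bfloat16)

    A = torch.randn(4096, device="cuda", dtype=torch.float32)
    q, state = F.quantize_blockwise(A)       # state carries the absmax buffer
    out = F.dequantize_blockwise(q, state)   # raised an error before this patch

    torch.set_default_dtype(torch.float32)
    print((A - out).abs().max())             # small blockwise-quantization error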