Fixed bug caused by undefined default type of absmax. #553
This commit is contained in:
parent
7b6cfe1738
commit
097b1cc5da
|
@ -264,3 +264,6 @@ Deprecated:
|
|||
|
||||
Features:
|
||||
- Added precompiled CUDA 11.8 binaries to support H100 GPUs without compilation #571
|
||||
|
||||
Bug fixes:
|
||||
- Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553
|
||||
|
|
|
@ -604,7 +604,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
|
|||
n = A.numel()
|
||||
blocks = n // blocksize
|
||||
blocks += 1 if n % blocksize > 0 else 0
|
||||
absmax = torch.zeros((blocks,), device=A.device)
|
||||
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
|
@ -684,6 +684,8 @@ def dequantize_blockwise(
|
|||
quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
|
||||
|
||||
absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
|
||||
|
||||
if absmax.dtype != torch.float32: absmax = absmax.float()
|
||||
if nested:
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
|
@ -796,7 +798,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
|
|||
if absmax is None:
|
||||
blocks = n // blocksize
|
||||
blocks += 1 if n % blocksize > 0 else 0
|
||||
absmax = torch.zeros((blocks,), device=A.device)
|
||||
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
|
||||
|
||||
|
||||
if out is None:
|
||||
|
@ -886,6 +888,7 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
|
|||
|
||||
|
||||
if compressed_stats is not None:
|
||||
if absmax.dtype != torch.float32: absmax = absmax.float()
|
||||
offset, state2 = compressed_stats
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
|
@ -930,6 +933,7 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
|
|||
code = code.to(A.device)
|
||||
|
||||
absmax = torch.abs(A).max()
|
||||
if absmax.dtype != torch.float32: absmax = absmax.float()
|
||||
inp = A / absmax
|
||||
out = quantize_no_absmax(inp, code, out)
|
||||
return out, (absmax, code)
|
||||
|
|
Loading…
Reference in New Issue
Block a user