Fixed bug caused by undefined default type of absmax. #553

This commit is contained in:
Tim Dettmers 2023-07-13 21:23:33 -07:00
parent 7b6cfe1738
commit 097b1cc5da
2 changed files with 9 additions and 2 deletions

View File

@ -264,3 +264,6 @@ Deprecated:
Features:
- Added precompiled CUDA 11.8 binaries to support H100 GPUs without compilation #571
Bug fixes:
- Fixed a bug where the default type of absmax was undefined which leads to errors if the default type is different than torch.float32. # 553

View File

@ -604,7 +604,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
n = A.numel()
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
@ -684,6 +684,8 @@ def dequantize_blockwise(
quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
if absmax.dtype != torch.float32: absmax = absmax.float()
if nested:
absmax = dequantize_blockwise(absmax, state2)
absmax += offset
@ -796,7 +798,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
if absmax is None:
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
if out is None:
@ -886,6 +888,7 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
if compressed_stats is not None:
if absmax.dtype != torch.float32: absmax = absmax.float()
offset, state2 = compressed_stats
absmax = dequantize_blockwise(absmax, state2)
absmax += offset
@ -930,6 +933,7 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
code = code.to(A.device)
absmax = torch.abs(A).max()
if absmax.dtype != torch.float32: absmax = absmax.float()
inp = A / absmax
out = quantize_no_absmax(inp, code, out)
return out, (absmax, code)