Added nested quantization for blockwise quantization.

Tim Dettmers 2023-04-19 11:48:47 -07:00
parent 7dc198feb7
commit 0f9d30207f
2 changed files with 55 additions and 42 deletions
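In short: blockwise quantization stores one fp32 absmax scale per block of values, and nested quantization reduces that overhead by centering the absmax tensor around its mean and blockwise-quantizing it a second time, so each block's scale costs roughly one byte plus a much smaller second-level state instead of four bytes. A minimal sketch of the idea in plain PyTorch (illustrative only, not the code in this commit; the helper name and the divisibility assumption are mine):

    # Sketch of the nested idea: quantize the per-block scales themselves.
    import torch

    def sketch_nested_absmax(A: torch.Tensor, blocksize: int = 64):
        # assumes A.numel() is a multiple of blocksize
        absmax = A.reshape(-1, blocksize).abs().max(dim=1).values  # one fp32 scale per block
        offset = absmax.mean()      # kept in fp32, as in the diff below
        shifted = absmax - offset   # centered scales are small and roughly symmetric
        # the commit stores `shifted` via quantize_blockwise(..., nested=False),
        # i.e. 8-bit codes plus a much smaller second-level absmax
        return shifted, offset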

bitsandbytes/functional.py

@@ -541,7 +541,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
    return out
-def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
+def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
    """
    Quantize tensor A in blocks of size 4096 values.
@@ -586,7 +586,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
        out = torch.zeros_like(A, dtype=torch.uint8)
    if A.device.type != 'cpu':
-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
+        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
        cblocksize = ct.c_int32(blocksize)
        prev_device = pre_call(A.device)
        code = code.to(A.device)
@@ -616,7 +616,15 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
        assert rand is None
        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
-    state = [absmax, code, blocksize]
+    if nested:
+        offset = absmax.mean()
+        absmax -= offset
+        qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
+        state = [qabsmax, code, blocksize, nested, offset, state2]
+    else:
+        state = [absmax, code, blocksize, nested, None, None]
    return out, state
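As exercised by the updated test below, a nested round trip through the public API looks roughly like this (assumes a CUDA device and this version of the API; tensor shape and blocksize are arbitrary):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1024, 1024, device="cuda")
    C, state = F.quantize_blockwise(A, blocksize=64, nested=True)
    # state is [qabsmax, code, blocksize, nested, offset, state2], so the
    # blocksize no longer has to be passed to dequantize_blockwise
    A2 = F.dequantize_blockwise(C, state)
    print((A - A2).abs().mean())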
@@ -628,6 +636,7 @@ def dequantize_blockwise(
    code: Tensor = None,
    out: Tensor = None,
    blocksize: int = 4096,
+    nested=False
) -> Tensor:
    """
    Dequantizes blockwise quantized values.
@@ -665,13 +674,15 @@ def dequantize_blockwise(
    if quant_state is None:
        quant_state = (absmax, code, blocksize)
    else:
-        absmax, code, blocksize = quant_state
+        absmax, code, blocksize, nested, offset, state2 = quant_state
+        if nested:
+            absmax = dequantize_blockwise(absmax, state2)
+            absmax += offset
    if A.device.type != 'cpu':
        device = pre_call(A.device)
        code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64, 32]:
+        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
        is_on_gpu([A, absmax, out])
        if out.dtype == torch.float32:
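For reference, the nested entries of that state can be unwrapped by hand the same way the branch above does it (continuing the round-trip sketch; illustrative only, F.dequantize_blockwise remains the intended entry point):

    qabsmax, code, blocksize, nested, offset, state2 = state
    if nested:
        absmax = F.dequantize_blockwise(qabsmax, state2) + offset  # restore fp32 per-block scales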
@@ -736,7 +747,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
    if out is None:
        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
-    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
    prev_device = pre_call(A.device)
    is_on_gpu([A, out, absmax])

tests/test_functional.py

@@ -150,15 +150,17 @@ def test_dynamic_quantization():
    assert diff < 0.004
-def test_dynamic_blockwise_quantization():
+@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
+@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
+def test_dynamic_blockwise_quantization(nested, blocksize):
    #print('')
-    for blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]:
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
-        C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-        A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
+        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
@@ -167,14 +169,14 @@ def test_dynamic_blockwise_quantization():
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.011
    assert relerr < 0.018
-    #print('randn', blocksize, sum(diffs)/len(diffs))
-    #print('randn', blocksize, sum(reldiffs)/len(reldiffs))
+    print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
+    print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
-        C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-        A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
+        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
@@ -184,8 +186,8 @@ def test_dynamic_blockwise_quantization():
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.0035
    assert relerr < 0.015
-    #print('rand', blocksize, sum(diffs)/len(diffs))
-    #print('rand', blocksize, sum(reldiffs)/len(reldiffs))
+    print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
+    print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
def test_dynamic_blockwise_stochastic_quantization():