Added nested quantization for blockwise quantization.
parent 7dc198feb7
commit 0f9d30207f
@@ -541,7 +541,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out


-def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
+def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
@@ -586,7 +586,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         out = torch.zeros_like(A, dtype=torch.uint8)

     if A.device.type != 'cpu':
-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
+        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
         cblocksize = ct.c_int32(blocksize)
         prev_device = pre_call(A.device)
         code = code.to(A.device)
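
Note that blocksize 32 is dropped from the CUDA-path assertion here (and from the matching checks in dequantize_blockwise and quantize_4bit further down). A minimal illustration of the new behaviour, assuming a CUDA device and the functional API used in this diff:

import torch
import bitsandbytes.functional as F

A = torch.randn(256, device="cuda")
try:
    F.quantize_blockwise(A, blocksize=32)  # no longer in the supported list after this commit
except (AssertionError, ValueError) as err:
    print("blocksize=32 rejected:", err)
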
@@ -616,7 +616,15 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

-    state = [absmax, code, blocksize]
+    if nested:
+        offset = absmax.mean()
+        absmax -= offset
+        qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
+        state = [qabsmax, code, blocksize, nested, offset, state2]
+    else:
+        state = [absmax, code, blocksize, nested, None, None]


     return out, state
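
In effect, the nested path treats the per-block absmax tensor as data to be quantized a second time: it is centered by its mean (the offset) and then passed through quantize_blockwise itself. A rough standalone sketch of that idea, using a toy symmetric 8-bit quantizer in place of the library's dynamic code table (names like nested_quantize_absmax are illustrative, not part of the API):

import torch

def nested_quantize_absmax(absmax: torch.Tensor, blocksize: int = 4096):
    # Center the absmax values around their mean, as the diff does, so the
    # second-level quantizer sees roughly zero-mean data.
    offset = absmax.mean()
    centered = (absmax - offset).flatten()
    # For brevity, assume the number of absmax values is a multiple of blocksize.
    blocks = centered.reshape(-1, blocksize)
    absmax2 = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)  # second-level absmax
    q = torch.clamp((blocks / absmax2 * 127).round(), -127, 127).to(torch.int8)
    return q, (absmax2, offset)

def nested_dequantize_absmax(q: torch.Tensor, state) -> torch.Tensor:
    absmax2, offset = state
    return (q.float() / 127 * absmax2).flatten() + offset
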
@@ -628,6 +636,7 @@ def dequantize_blockwise(
     code: Tensor = None,
     out: Tensor = None,
     blocksize: int = 4096,
+    nested=False
 ) -> Tensor:
     """
     Dequantizes blockwise quantized values.
@@ -665,13 +674,15 @@ def dequantize_blockwise(
     if quant_state is None:
         quant_state = (absmax, code, blocksize)
     else:
-        absmax, code, blocksize = quant_state
+        absmax, code, blocksize, nested, offset, state2 = quant_state
+        if nested:
+            absmax = dequantize_blockwise(absmax, state2)
+            absmax += offset

     if A.device.type != 'cpu':
         device = pre_call(A.device)
         code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64, 32]:
+        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
             raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, absmax, out])
         if out.dtype == torch.float32:
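
On the dequantize side the extended six-element quant_state is unpacked and, when nested, the quantized absmax is dequantized first and the offset added back. Continuing the toy sketch above, a quick round-trip check (purely illustrative) would be:

absmax = torch.rand(8 * 4096) * 3.0                 # fake per-block absmax values
q, st = nested_quantize_absmax(absmax, blocksize=4096)
restored = nested_dequantize_absmax(q, st)
print((restored - absmax).abs().max())              # error on the order of absmax2 / 127
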
@@ -736,7 +747,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     if out is None:
         out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)

-    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]

     prev_device = pre_call(A.device)
     is_on_gpu([A, out, absmax])
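
The point of the extra state handling is memory: with plain blockwise quantization each block of blocksize values carries one fp32 absmax, whereas with nesting it carries one uint8 absmax plus a much smaller second-level fp32 state. A back-of-the-envelope calculation under those assumptions (fp32 first-level absmax, the same blocksize at both levels, ignoring the single fp32 offset):

def state_bits_per_element(blocksize: int, nested: bool) -> float:
    if not nested:
        return 32 / blocksize                     # one fp32 absmax per block
    return 8 / blocksize + 32 / blocksize ** 2    # uint8 absmax + fp32 absmax-of-absmax

print(state_bits_per_element(4096, False))  # ~0.0078 bits per element
print(state_bits_per_element(4096, True))   # ~0.0020 bits per element
print(state_bits_per_element(64, False))    # 0.5 bits per element
print(state_bits_per_element(64, True))     # ~0.13 bits per element
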
@@ -150,42 +150,44 @@ def test_dynamic_quantization():
     assert diff < 0.004


-def test_dynamic_blockwise_quantization():
-    #print('')
-    for blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]:
-        diffs = []
-        reldiffs = []
-        for i in range(100):
-            A1 = torch.randn(1024, 1024, device="cuda")
-            C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-            A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-            diff = torch.abs(A1 - A2)
-            reldiff = diff / torch.abs(A1 + 1e-8)
-            diffs.append(diff.mean().item())
-            reldiffs.append(reldiff.mean().item())
-        abserr = sum(diffs)/len(diffs)
-        relerr = sum(reldiffs)/len(reldiffs)
-        assert abserr < 0.011
-        assert relerr < 0.018
-        #print('randn', blocksize, sum(diffs)/len(diffs))
-        #print('randn', blocksize, sum(reldiffs)/len(reldiffs))
-
-        diffs = []
-        for i in range(100):
-            A1 = torch.rand(1024, 1024, device="cuda")
-            C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-            A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-            diff = torch.abs(A1 - A2)
-            reldiff = diff / torch.abs(A1 + 1e-8)
-            diffs.append(diff.mean().item())
-            reldiffs.append(reldiff.mean().item())
-            #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
-        abserr = sum(diffs)/len(diffs)
-        relerr = sum(reldiffs)/len(reldiffs)
-        assert abserr < 0.0035
-        assert relerr < 0.015
-        #print('rand', blocksize, sum(diffs)/len(diffs))
-        #print('rand', blocksize, sum(reldiffs)/len(reldiffs))
+@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
+@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
+def test_dynamic_blockwise_quantization(nested, blocksize):
+    #print('')
+    diffs = []
+    reldiffs = []
+    for i in range(100):
+        A1 = torch.randn(1024, 1024, device="cuda")
+        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+        A2 = F.dequantize_blockwise(C, S)
+        diff = torch.abs(A1 - A2)
+        reldiff = diff / torch.abs(A1 + 1e-8)
+        diffs.append(diff.mean().item())
+        reldiffs.append(reldiff.mean().item())
+    abserr = sum(diffs)/len(diffs)
+    relerr = sum(reldiffs)/len(reldiffs)
+    assert abserr < 0.011
+    assert relerr < 0.018
+    print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
+    print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
+
+    diffs = []
+    for i in range(100):
+        A1 = torch.rand(1024, 1024, device="cuda")
+        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+        A2 = F.dequantize_blockwise(C, S)
+        diff = torch.abs(A1 - A2)
+        reldiff = diff / torch.abs(A1 + 1e-8)
+        diffs.append(diff.mean().item())
+        reldiffs.append(reldiff.mean().item())
+        #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
+    abserr = sum(diffs)/len(diffs)
+    relerr = sum(reldiffs)/len(reldiffs)
+    assert abserr < 0.0035
+    assert relerr < 0.015
+    print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
+    print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))


 def test_dynamic_blockwise_stochastic_quantization():
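
For reference, end-to-end usage of the new flag looks much like the parametrized test above (assuming a CUDA device):

import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device="cuda")
C, state = F.quantize_blockwise(A, blocksize=4096, nested=True)
A_restored = F.dequantize_blockwise(C, state)
print((A - A_restored).abs().mean())  # should stay well below the 0.011 abserr bound used in the test
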