Fixed two bugs in dynamic data type creation.
This commit is contained in:
parent
a06a0f6a08
commit
3c9aca9124
|
@ -322,10 +322,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
|
||||||
# these are additional items that come from the case
|
# these are additional items that come from the case
|
||||||
# where all the exponent bits are zero and no
|
# where all the exponent bits are zero and no
|
||||||
# indicator bit is present
|
# indicator bit is present
|
||||||
non_sign_bits = total_bits - (1 if signed else 0)
|
non_sign_bits = total_bits - (1 if signed else 1)
|
||||||
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
|
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
|
||||||
if not signed:
|
|
||||||
additional_items = 2 * additional_items
|
|
||||||
for i in range(max_exponent_bits):
|
for i in range(max_exponent_bits):
|
||||||
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
|
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
|
||||||
boundaries = torch.linspace(0.1, 1, fraction_items)
|
boundaries = torch.linspace(0.1, 1, fraction_items)
|
||||||
|
@ -334,16 +332,18 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
|
||||||
if signed:
|
if signed:
|
||||||
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
||||||
|
|
||||||
if additional_items > 0:
|
if additional_items > 0:
|
||||||
boundaries = torch.linspace(0.1, 1, additional_items + 1)
|
boundaries = torch.linspace(0.1, 1, additional_items + 1)
|
||||||
means = (boundaries[:-1] + boundaries[1:]) / 2.0
|
means = (boundaries[:-1] + boundaries[1:]) / 2.0
|
||||||
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
||||||
if signed:
|
if signed:
|
||||||
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
||||||
|
|
||||||
data.append(0)
|
data.append(0)
|
||||||
data.append(1.0)
|
data.append(1.0)
|
||||||
|
|
||||||
|
assert len(data) == 2**total_bits
|
||||||
|
|
||||||
gap = 256 - len(data)
|
gap = 256 - len(data)
|
||||||
for i in range(gap):
|
for i in range(gap):
|
||||||
data.append(0)
|
data.append(0)
|
||||||
|
|
|
@ -129,6 +129,7 @@ def test_quantile_quantization():
|
||||||
assert diff < 0.001
|
assert diff < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_dynamic_quantization():
|
def test_dynamic_quantization():
|
||||||
diffs = []
|
diffs = []
|
||||||
reldiffs = []
|
reldiffs = []
|
||||||
|
@ -141,8 +142,8 @@ def test_dynamic_quantization():
|
||||||
diffs.append(diff.mean().item())
|
diffs.append(diff.mean().item())
|
||||||
reldiffs.append(reldiff.mean().item())
|
reldiffs.append(reldiff.mean().item())
|
||||||
assert diff.mean().item() < 0.0135
|
assert diff.mean().item() < 0.0135
|
||||||
# print(sum(diffs)/len(diffs))
|
print(sum(diffs)/len(diffs))
|
||||||
# print(sum(reldiffs)/len(reldiffs))
|
print(sum(reldiffs)/len(reldiffs))
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
A1 = torch.rand(1024, 1024, device="cuda")
|
A1 = torch.rand(1024, 1024, device="cuda")
|
||||||
|
@ -157,7 +158,8 @@ def test_dynamic_quantization():
|
||||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||||
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
|
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
|
||||||
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
||||||
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
|
@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False'])
|
||||||
|
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
|
||||||
#print('')
|
#print('')
|
||||||
diffs = []
|
diffs = []
|
||||||
reldiffs = []
|
reldiffs = []
|
||||||
|
@ -178,9 +180,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
|
||||||
assert A2.dtype == dtype
|
assert A2.dtype == dtype
|
||||||
|
|
||||||
diffs = []
|
diffs = []
|
||||||
|
code = F.create_dynamic_map(signed=signed)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
|
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
|
||||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
|
||||||
A2 = F.dequantize_blockwise(C, S)
|
A2 = F.dequantize_blockwise(C, S)
|
||||||
diff = torch.abs(A1 - A2).float()
|
diff = torch.abs(A1 - A2).float()
|
||||||
reldiff = diff / torch.abs(A1.float() + 1e-8)
|
reldiff = diff / torch.abs(A1.float() + 1e-8)
|
||||||
|
@ -189,11 +192,15 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
|
||||||
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
||||||
abserr = sum(diffs)/len(diffs)
|
abserr = sum(diffs)/len(diffs)
|
||||||
relerr = sum(reldiffs)/len(reldiffs)
|
relerr = sum(reldiffs)/len(reldiffs)
|
||||||
assert abserr < 0.0035
|
if signed:
|
||||||
assert relerr < 0.015
|
assert abserr < 0.0035
|
||||||
|
assert relerr < 0.015
|
||||||
|
else:
|
||||||
|
assert abserr < 0.00175
|
||||||
|
assert relerr < 0.012
|
||||||
assert A2.dtype == dtype
|
assert A2.dtype == dtype
|
||||||
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
|
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
|
||||||
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
|
#print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user