Fixed two bugs in dynamic data type creation.

This commit is contained in:
Tim Dettmers 2023-08-03 19:47:15 -07:00
parent a06a0f6a08
commit 3c9aca9124
2 changed files with 24 additions and 17 deletions

View File

@@ -322,10 +322,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
# these are additional items that come from the case # these are additional items that come from the case
# where all the exponent bits are zero and no # where all the exponent bits are zero and no
# indicator bit is present # indicator bit is present
non_sign_bits = total_bits - (1 if signed else 0) non_sign_bits = total_bits - (1 if signed else 1)
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
if not signed:
additional_items = 2 * additional_items
for i in range(max_exponent_bits): for i in range(max_exponent_bits):
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
boundaries = torch.linspace(0.1, 1, fraction_items) boundaries = torch.linspace(0.1, 1, fraction_items)
@@ -344,6 +342,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.append(0) data.append(0)
data.append(1.0) data.append(1.0)
assert len(data) == 2**total_bits
gap = 256 - len(data) gap = 256 - len(data)
for i in range(gap): for i in range(gap):
data.append(0) data.append(0)

View File

@@ -129,6 +129,7 @@ def test_quantile_quantization():
assert diff < 0.001 assert diff < 0.001
def test_dynamic_quantization(): def test_dynamic_quantization():
diffs = [] diffs = []
reldiffs = [] reldiffs = []
@@ -141,8 +142,8 @@ def test_dynamic_quantization():
diffs.append(diff.mean().item()) diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item()) reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135 assert diff.mean().item() < 0.0135
# print(sum(diffs)/len(diffs)) print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs)) print(sum(reldiffs)/len(reldiffs))
for i in range(100): for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda") A1 = torch.rand(1024, 1024, device="cuda")
@@ -157,7 +158,8 @@ def test_dynamic_quantization():
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) @pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
def test_dynamic_blockwise_quantization(dtype, nested, blocksize): @pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False'])
def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
#print('') #print('')
diffs = [] diffs = []
reldiffs = [] reldiffs = []
@@ -178,9 +180,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
assert A2.dtype == dtype assert A2.dtype == dtype
diffs = [] diffs = []
code = F.create_dynamic_map(signed=signed)
for i in range(100): for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype) A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
A2 = F.dequantize_blockwise(C, S) A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).float() diff = torch.abs(A1 - A2).float()
reldiff = diff / torch.abs(A1.float() + 1e-8) reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -189,11 +192,15 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs)/len(diffs) abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs) relerr = sum(reldiffs)/len(reldiffs)
if signed:
assert abserr < 0.0035 assert abserr < 0.0035
assert relerr < 0.015 assert relerr < 0.015
else:
assert abserr < 0.00175
assert relerr < 0.012
assert A2.dtype == dtype assert A2.dtype == dtype
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))