Added quantization tree generation.

This commit is contained in:
Tim Dettmers 2023-04-02 14:42:45 -07:00
parent 0d332a641f
commit 4ad999d144
2 changed files with 17 additions and 1 deletions

View File

@ -218,7 +218,7 @@ def create_custom_map(seed=0, scale=0.01):
assert values.numel() == 256
return values
def create_normal_map(offset=0.966666, use_extra_value=True):
def create_normal_map(offset=0.9677083, use_extra_value=True):
if use_extra_value:
# one more positive value, this is an asymmetric type

View File

@ -2318,3 +2318,19 @@ def test_bench_fp4_dequant():
# torch.matmul(b, a.t())
#torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6)
def test_normal_map_tree():
code = F.create_normal_map()
values =code[:8].tolist() + code[-8:].tolist()
num_pivots = 1
while num_pivots <16:
idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
print(idx)
num_pivots *= 2
pivots = []
for i in idx:
pivots.append((values[i-1]+values[i])/2)
print(pivots)