Added quantization tree generation.
This commit is contained in:
parent
0d332a641f
commit
4ad999d144
|
@ -218,7 +218,7 @@ def create_custom_map(seed=0, scale=0.01):
|
||||||
assert values.numel() == 256
|
assert values.numel() == 256
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def create_normal_map(offset=0.966666, use_extra_value=True):
|
def create_normal_map(offset=0.9677083, use_extra_value=True):
|
||||||
|
|
||||||
if use_extra_value:
|
if use_extra_value:
|
||||||
# one more positive value, this is an asymmetric type
|
# one more positive value, this is an asymmetric type
|
||||||
|
|
|
@ -2318,3 +2318,19 @@ def test_bench_fp4_dequant():
|
||||||
# torch.matmul(b, a.t())
|
# torch.matmul(b, a.t())
|
||||||
#torch.cuda.synchronize()
|
#torch.cuda.synchronize()
|
||||||
#print((time.time()-t0)/iters*1e6)
|
#print((time.time()-t0)/iters*1e6)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_normal_map_tree():
|
||||||
|
code = F.create_normal_map()
|
||||||
|
values =code[:8].tolist() + code[-8:].tolist()
|
||||||
|
num_pivots = 1
|
||||||
|
while num_pivots <16:
|
||||||
|
idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
|
||||||
|
print(idx)
|
||||||
|
num_pivots *= 2
|
||||||
|
pivots = []
|
||||||
|
for i in idx:
|
||||||
|
pivots.append((values[i-1]+values[i])/2)
|
||||||
|
print(pivots)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user