Added normal quant.
This commit is contained in:
parent
69810521d3
commit
8645d1f71c
|
@ -9,6 +9,8 @@ import random
|
|||
import torch
|
||||
import itertools
|
||||
import math
|
||||
import scipy.stats
|
||||
import numpy as np
|
||||
|
||||
from functools import reduce # Required in Python 3
|
||||
from typing import Tuple
|
||||
|
@ -152,6 +154,70 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True):
|
|||
#return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
|
||||
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
|
||||
|
||||
def custom_map(seed=0, scale=0.01):
|
||||
v = [12, 10, 8, 6, 3, 2, 1]
|
||||
# 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45
|
||||
# 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48
|
||||
|
||||
# 13B 100 steps:
|
||||
# - 4-bit evo: 86.02
|
||||
# - 4-bit norm: 78.73
|
||||
# - 4-bit FP4:
|
||||
# - 16-bit:
|
||||
|
||||
# interval search on normal distribution
|
||||
#v = [3.090232306167813, 1.4589770349449647, 1.064410327932115, 0.7896806653244509, 0.5646884166925807, 0.3653406435875121, 0.17964844284441311] # 0.999 26.5
|
||||
#v = [2.3263478740408408, 1.4050715603096329, 1.0364333894937898, 0.7721932141886848, 0.5533847195556727, 0.3584587932511938, 0.1763741647808615] # 0.99 24.99
|
||||
#v = [1.6448536269514722, 1.2040469600267016, 0.9208229763683788, 0.6971414348463417, 0.5039653672113453, 0.3280721075316511, 0.16184416680396213] # 0.95 24.53 22.97
|
||||
#v = [1.4050715603096329, 1.0803193408149558, 0.8416212335729143, 0.643345405392917, 0.4676987991145084, 0.3054807880993974, 0.1509692154967774] # 0.92 24.81
|
||||
#v = [1.2815515655446004, 1.0062699858608395, 0.7916386077433746, 0.6084981344998837, 0.4438613119262478, 0.29050677112339396, 0.14372923370582416] # 0.9 24.68
|
||||
#v = [1.8807936081512509, 1.2980047163986055, 0.9769954022693226, 0.7341502955472268, 0.5285136765472481, 0.343225833559403, 0.16910470304375366] # 0.97 25.03
|
||||
#v = [1.7506860712521692, 1.2496468758017434, 0.9485350408266378, 0.7155233557034365, 0.5162006366043174, 0.3356393360829622, 0.16547334454641704] # 0.96 24.85 23.01
|
||||
#v = [1.5547735945968535, 1.1608220210715001, 0.893800631179489, 0.6789921163940618, 0.4918050830048072, 0.3205236191093902, 0.15821711945563585] # 0.94 24.47
|
||||
#v = [1.475791028179171, 1.1196635980209986, 0.8674156943957149, 0.6610637542614526, 0.4797170937629045, 0.31299335020578195, 0.15459215234139795] # 0.93 24.85
|
||||
#v = [1.5981931399228175, 1.1821583959486879, 0.9072289939325966, 0.6880384454306778, 0.49787602226482025, 0.3242955535308664, 0.160030379970179] # 0.945 24.287
|
||||
##v = [1.6164363711150211, 1.1908453913294612, 0.9126463450304729, 0.6916727602238111, 0.5003095327012462, 0.3258056171348078, 0.1607558311941979] # 0.947 24.293
|
||||
#v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207
|
||||
#v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30
|
||||
#v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293
|
||||
#v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88
|
||||
|
||||
# 7B evo start
|
||||
#v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06
|
||||
#v = [1.6143079205628337, 1.1888081407660314, 0.8990131955745421, 0.694373759813679, 0.5083033257326773, 0.3452499746844963, 0.1148939728228951]
|
||||
#v = [1.614442766030303, 1.189401918639665, 0.8998038168964273, 0.6953094818279475, 0.5073264599048384, 0.3449003790823619, 0.11428378427205564]
|
||||
|
||||
# 13B evo start
|
||||
#v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
|
||||
#v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
|
||||
v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]
|
||||
|
||||
# mean evo 7B + 13B
|
||||
#v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]
|
||||
|
||||
# theoretically optiomal (0.93333)
|
||||
# v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333
|
||||
|
||||
|
||||
|
||||
if seed > 0:
|
||||
v = np.array(v)
|
||||
np.random.seed(seed)
|
||||
v += np.random.randn(7)*scale
|
||||
print(v.tolist())
|
||||
#v[0] += (np.random.randn(1)*0.001)[0]
|
||||
#v[-1] += (np.random.randn(1)*0.001)[0]
|
||||
#print(v[0], v[-1])
|
||||
v = v.tolist()
|
||||
values = v + [0]*(256-14) + \
|
||||
v[::-1]
|
||||
|
||||
values = torch.Tensor(values)
|
||||
values[0:7] *= -1
|
||||
values = values.sort().values
|
||||
values /= values.max()
|
||||
assert values.numel() == 256
|
||||
return values
|
||||
|
||||
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
|
||||
e = exponent_bits
|
||||
|
@ -168,7 +234,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
|||
values = []
|
||||
lst = list(itertools.product([0, 1], repeat=precision_bits))
|
||||
#for ev in evalues:
|
||||
bias = 2**(exponent_bits-1)+1
|
||||
bias = 2**(exponent_bits-1)-1
|
||||
for evalue in range(2**(exponent_bits)):
|
||||
for bit_pattern in lst:
|
||||
value = (1 if evalue != 0 else 0)
|
||||
|
@ -176,10 +242,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
|||
value += pval*(2**-(i+1))
|
||||
if evalue == 0:
|
||||
# subnormals
|
||||
value = value*2**-(bias)
|
||||
value = value*2**-(bias-1)
|
||||
else:
|
||||
# normals
|
||||
value = value*2**-(evalue-bias-1)
|
||||
value = value*2**-(evalue-bias-2)
|
||||
values.append(value)
|
||||
if signed:
|
||||
values.append(-value)
|
||||
|
@ -502,7 +568,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
|
||||
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
|
||||
cblocksize = ct.c_int32(blocksize)
|
||||
prev_device = pre_call(A.device)
|
||||
code = code.to(A.device)
|
||||
|
@ -585,7 +651,7 @@ def dequantize_blockwise(
|
|||
if A.device.type != 'cpu':
|
||||
device = pre_call(A.device)
|
||||
code = code.to(A.device)
|
||||
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
|
||||
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64, 32]:
|
||||
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
|
||||
is_on_gpu([A, absmax, out])
|
||||
if out.dtype == torch.float32:
|
||||
|
|
|
@ -2953,6 +2953,8 @@ template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 0>(float * code, ha
|
|||
template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 32, 1, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 32, 1, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
|
||||
template __global__ void kQuantizeBlockwise<half, 4096, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
|
@ -2968,8 +2970,6 @@ template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 1>(float * code, ha
|
|||
template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
//template __global__ void kQuantizeBlockwise<half, 64, 1, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
//template __global__ void kQuantizeBlockwise<float, 64, 1, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
|
|
|
@ -71,8 +71,8 @@ template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * co
|
|||
kQuantizeBlockwise<T, 128, 2, 0, FP4><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 64)
|
||||
kQuantizeBlockwise<T, 64, 2, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
//else if(blocksize == 32)
|
||||
//kQuantizeBlockwise<T, 32, 1, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 32 and FP4 == 0)
|
||||
kQuantizeBlockwise<T, 32, 1, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
|
|
@ -152,7 +152,7 @@ def test_dynamic_quantization():
|
|||
|
||||
def test_dynamic_blockwise_quantization():
|
||||
#print('')
|
||||
for blocksize in [4096, 2048, 1024, 512, 256, 128, 64]:
|
||||
for blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]:
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
|
@ -167,8 +167,8 @@ def test_dynamic_blockwise_quantization():
|
|||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.011
|
||||
assert relerr < 0.018
|
||||
#print('randn', blocksize, sum(diffs)/len(diffs))
|
||||
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
print('randn', blocksize, sum(diffs)/len(diffs))
|
||||
print('randn', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
|
@ -184,8 +184,8 @@ def test_dynamic_blockwise_quantization():
|
|||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.0035
|
||||
assert relerr < 0.015
|
||||
#print('rand', blocksize, sum(diffs)/len(diffs))
|
||||
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
print('rand', blocksize, sum(diffs)/len(diffs))
|
||||
print('rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
||||
def test_dynamic_blockwise_stochastic_quantization():
|
||||
|
|
Loading…
Reference in New Issue
Block a user