Added 8-bit compression to quantization statistics.
parent c4cfe4fbdd
commit 51a21df728
@@ -155,7 +155,7 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True):
         #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
         return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
 
-def custom_map(seed=0, scale=0.01):
+def create_custom_map(seed=0, scale=0.01):
     v = [12, 10, 8, 6, 3, 2, 1]
     # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45
     # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48
@@ -191,13 +191,13 @@ def custom_map(seed=0, scale=0.01):
     # 13B evo start
     #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
     #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
-    v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]
+    #v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]
 
     # mean evo 7B + 13B
     #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]
 
     # theoretically optiomal (0.93333)
-    # v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333
+    v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333
 
 
 
@@ -599,7 +599,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
-    return out, (absmax, code)
+    state = (absmax, code, blocksize)
+
+    return out, state
 
 
 def dequantize_blockwise(
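Note on this hunk: quantize_blockwise now returns a three-element state that also carries the blocksize, and dequantize_blockwise (next hunk) unpacks it accordingly. A minimal round-trip sketch, assuming the public functional API at this commit (not part of the diff):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4096, device='cuda')
    qA, state = F.quantize_blockwise(A, blocksize=256)   # state == (absmax, code, blocksize)
    A_restored = F.dequantize_blockwise(qA, state)       # blocksize now travels with the state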
|
@ -644,9 +646,9 @@ def dequantize_blockwise(
|
|||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.float32)
|
||||
if quant_state is None:
|
||||
quant_state = (absmax, code)
|
||||
quant_state = (absmax, code, blocksize)
|
||||
else:
|
||||
absmax, code = quant_state
|
||||
absmax, code, blocksize = quant_state
|
||||
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
|
@@ -669,7 +671,7 @@ def dequantize_blockwise(
     return out
 
 
-def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64) -> Tensor:
+def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False) -> Tensor:
     """
     Quantize tensor A in blocks of FP4 values.
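quantize_fp4 gains a compress_statistics flag that defaults to False, so existing callers keep the uncompressed path. A hedged summary of the two quant-state layouts this produces (implemented in the @@ -722 hunk further down; the field names here are descriptive, not from the source):

    # compress_statistics=False: state = (absmax_fp32,   input_shape, A.dtype, blocksize, None)
    # compress_statistics=True:  state = (qabsmax_uint8, input_shape, A.dtype, blocksize, (offset, absmax_quant_state))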
@@ -704,12 +706,11 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
     blocks += 1 if n % blocksize > 0 else 0
     absmax = torch.zeros((blocks,), device=A.device)
 
-    state = (absmax, input_shape, A.dtype, blocksize)
 
     if out is None:
         out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
 
-    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
 
     prev_device = pre_call(A.device)
     is_on_gpu([A, out, absmax])
|
@ -722,6 +723,17 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
|
|||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
||||
if compress_statistics:
|
||||
offset = absmax.mean()
|
||||
absmax -= offset
|
||||
#code = create_custom_map().to(absmax.device)
|
||||
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
|
||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
|
||||
del absmax
|
||||
state = (qabsmax, input_shape, A.dtype, blocksize, (offset, state2))
|
||||
else:
|
||||
state = (absmax, input_shape, A.dtype, blocksize, None)
|
||||
|
||||
return out, state
|
||||
|
||||
|
||||
|
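The compression itself is a second round of quantization over the per-block absmax statistics: subtract their mean (stored as a float offset in the state) and 8-bit blockwise-quantize the remainder with blocksize 256. Rough overhead arithmetic under the defaults shown above (fp32 absmax, FP4 blocksize 64), a sketch rather than a measured number:

    # extra bits spent on quantization statistics, per quantized parameter
    uncompressed = 32 / 64                    # 0.5 bits/param: one fp32 absmax per 64 values
    compressed   = 8 / 64 + 32 / (64 * 256)   # ~0.127 bits/param once absmax itself is 8-bit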
@@ -756,8 +768,12 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype, blocksize = quant_state
+        absmax, shape, dtype, blocksize, compressed_stats = quant_state
+
+    if compressed_stats is not None:
+        offset, state2 = compressed_stats
+        absmax = dequantize_blockwise(absmax, state2)
+        absmax += offset
 
     if out is None:
         out = torch.empty(shape, dtype=dtype, device=A.device)
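On the dequantization side the fifth tuple element signals a compressed state, and absmax is reconstructed (blockwise dequantize plus the stored offset) before the FP4 kernel runs. A hedged round-trip sketch mirroring the test added near the end of this diff:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1024, 1024, device='cuda').half()
    qA, state = F.quantize_fp4(A, blocksize=64, compress_statistics=True)
    A_restored = F.dequantize_fp4(qA, state)   # absmax is decompressed internally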
@@ -1986,8 +2002,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     ccolsB = ct.c_int32(B.shape[1])
     cldb = ct.c_int32(ldb)
     cldc = ct.c_int32(ldc)
-    # print(cooA.rowidx[:64])
-    # print(cooA.colidx[:64].sort()[0])
 
     is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
     if B.dtype == torch.float16:
@@ -134,15 +134,17 @@ class Embedding(torch.nn.Embedding):
         return emb
 
 class FP4Params(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, quant_state=None):
+    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True):
         cls.quant_state = None
+        cls.blocksize = blocksize
+        cls.compress_statistics = compress_statistics
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
 
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
-        w_fp4, quant_state = bnb.functional.quantize_fp4(w)
+        w_fp4, quant_state = bnb.functional.quantize_fp4(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics)
         self.data = w_fp4
         self.quant_state = quant_state
@@ -173,10 +175,10 @@ class FP4Params(torch.nn.Parameter):
 
 
 class LinearFP4(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
         super().__init__(input_features, output_features, bias)
         self.state = bnb.MatmulLtState()
-        self.weight = FP4Params(self.weight.data, requires_grad=False)
+        self.weight = FP4Params(self.weight.data, requires_grad=False, compress_statistics=compress_statistics)
         self.compute_dtype = compute_dtype
 
     def init_8bit_state(self):
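At the module level compress_statistics defaults to True for FP4Params and LinearFP4, and quantization is triggered when the parameter moves to the GPU. A hedged usage sketch following the pattern of the tests below (not taken from the commit itself; the .half().cuda() ordering is an assumption):

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.LinearFP4(4096, 4096, compress_statistics=True).half().cuda()  # .cuda() quantizes the weight to FP4
    x = torch.randn(8, 4096, dtype=torch.float16, device='cuda')
    y = layer(x)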
@@ -454,14 +454,15 @@ for c in req_grad:
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
 dtype = [torch.float16, torch.float32]
+compress_statistics = [False, True]
 has_fp16_weights = [True, False]
 has_bias = [True, False]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
-str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
-names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics".format(*vals) for vals in str_values]
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
-def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics", values, ids=names)
+def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
@@ -481,7 +482,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         bias2 = bias.clone()
     torch.nn.init.xavier_uniform_(B)
 
-    B2, quant_state = bnb.functional.quantize_fp4(B)
+    B2, quant_state = bnb.functional.quantize_fp4(B, compress_statistics=compress_statistics)
 
     if not transpose[0] and transpose[1]:
         out_torch = funcs[0](A, B.t())
@@ -167,8 +167,8 @@ def test_dynamic_blockwise_quantization():
         relerr = sum(reldiffs)/len(reldiffs)
         assert abserr < 0.011
         assert relerr < 0.018
-        print('randn', blocksize, sum(diffs)/len(diffs))
-        print('randn', blocksize, sum(reldiffs)/len(reldiffs))
+        #print('randn', blocksize, sum(diffs)/len(diffs))
+        #print('randn', blocksize, sum(reldiffs)/len(reldiffs))
 
         diffs = []
         for i in range(100):
@@ -184,8 +184,8 @@ def test_dynamic_blockwise_quantization():
         relerr = sum(reldiffs)/len(reldiffs)
         assert abserr < 0.0035
         assert relerr < 0.015
-        print('rand', blocksize, sum(diffs)/len(diffs))
-        print('rand', blocksize, sum(reldiffs)/len(reldiffs))
+        #print('rand', blocksize, sum(diffs)/len(diffs))
+        #print('rand', blocksize, sum(reldiffs)/len(reldiffs))
 
 
 def test_dynamic_blockwise_stochastic_quantization():
@@ -1806,6 +1806,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.nn.init.xavier_uniform_(B)
 
     B_fp4, state = F.quantize_fp4(B)
+    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
 
     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit.eval()
@@ -1839,6 +1840,13 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
 
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_fp4(A, B_fp4.t(), quant_state=state_c)
+    torch.cuda.synchronize()
+    print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+
     #torch.cuda.synchronize()
     #t0 = time.time()
     #for i in range(iters):
@@ -2244,6 +2252,42 @@ def test_fp4_quant():
     assert relerr.item() < 0.28
 
 
+def test_fp4_compressed_stats():
+    for blocksize in [128, 64]:
+        errs1 = []
+        errs2 = []
+        for i in range(10):
+            A1 = torch.randn(1024, 1024, device='cuda').half()
+            q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
+            q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
+            A2 = F.dequantize_fp4(q2, SA2)
+            A3 = F.dequantize_fp4(q3, SA3)
+
+
+            err = (A1 - A2).abs().float()
+            relerr = (err/(A1.abs().float()+1e-15)).mean()
+            err = err.mean()
+
+            errs1.append(err.item())
+
+            assert err.item() < 0.11
+            assert relerr.item() < 0.28
+
+            err = (A1 - A3).abs().float()
+            relerr = (err/(A1.abs().float()+1e-15)).mean()
+            err = err.mean()
+
+            errs2.append(err.item())
+
+            assert err.item() < 0.11
+            assert relerr.item() < 0.28
+
+        #print(sum(errs1)/len(errs1), blocksize)
+        #print(sum(errs2)/len(errs2), blocksize)
+
+
+
+
 def test_bench_fp4_dequant():
     blocksize = 256
     a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
@@ -507,7 +507,7 @@ def test_linear_kbit_fp32_bias(module):
     assert l1.bias is None
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4, lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)], ids=['Int8Lt', 'FP4', 'FP4+C'])
 def test_kbit_backprop(module):
     b = 17
     dim1 = 37