Added double quantization support and tests.
commit 0f0390acb2
parent 94168d79d7
bitsandbytes/functional.py

@@ -1461,16 +1461,25 @@ def gemv_4bit(
     transposed_B=False,
     state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
-        Bshape = B.shape
-        bout = Bshape[1]
-    else:
-        Bshape = state[1]
-        bout = Bshape[0]
+        raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')
+
+    Bshape = state[1]
+    bout = Bshape[0]
+    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
+    if compressed_stats is not None:
+        offset, state2 = compressed_stats
+        absmax = dequantize_blockwise(absmax, state2)
+        absmax += offset
 
     if out is None:
         out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
 
     sA = A.shape
     sB = B.shape
     if transposed_A and len(sA) == 2:
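For context on the new branch: `compressed_stats` is produced by quantize_4bit(..., compress_statistics=True), which re-quantizes the per-block fp32 absmax values to 8 bits after subtracting their mean. A minimal round-trip sketch (tensor sizes and the name `W` are illustrative; the unpacking mirrors the state layout used above):

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')

    # With compress_statistics=True the fp32 absmax vector is itself
    # blockwise-quantized to 8 bits; the state then carries the compressed
    # absmax plus (offset, state2) needed to reconstruct it.
    qW, state = F.quantize_4bit(W, quant_type='nf4', compress_statistics=True)
    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state

    # The same decode that gemv_4bit() now performs before touching the kernel:
    if compressed_stats is not None:
        offset, state2 = compressed_stats
        absmax = F.dequantize_blockwise(absmax, state2) + offset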
@@ -1557,14 +1566,16 @@ def gemv_4bit(
 
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
 
     post_call(prev_device)
 
     return out
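The only functional change in this hunk is the third kernel argument: state[0] is the absmax tensor as stored (still 8-bit when double quantization is active), while the local `absmax` was decompressed above, which is what the naive GEMV kernels expect. A hedged sketch of the inference path these calls serve (shapes are illustrative):

    import torch
    import bitsandbytes.functional as F

    x = torch.randn(1, 4096, dtype=torch.float16, device='cuda')      # single token
    W = torch.randn(11008, 4096, dtype=torch.float16, device='cuda')  # fp16 weight
    qW, state = F.quantize_4bit(W, quant_type='nf4', compress_statistics=True)

    # gemv_4bit() decompresses absmax (if needed) and dispatches to the
    # fp16 naive kernel above; the result has shape [1, 11008].
    y = F.gemv_4bit(x, qW.t(), state=state)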
tests/test_functional.py

@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)
 
 
-batch_size = 1
+batch_size = 5
 seqdim = 1
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))
@@ -1786,8 +1786,8 @@ values = []
 #values.append((batch_size, seqdim, 2560, 4*2560))
 #values.append((batch_size, seqdim, 4096, 4*4096))
 #values.append((batch_size, seqdim, 5120, 4*5120))
-#values.append((batch_size, seqdim, 6656, 4*6656))
-values.append((batch_size, seqdim, 8192, 4*8192))
+values.append((batch_size, seqdim, 6656, 4*6656))
+#values.append((batch_size, seqdim, 8192, 4*8192))
 #values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@@ -1804,6 +1804,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
 
     B_nf4, state_nf4 = F.quantize_nf4(B)
+    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
 
     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
     linear8bit.eval()
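Why benchmark the _c variants separately: double quantization trades a little decode work for smaller quantization constants. Rough per-parameter storage, assuming the library defaults of blocksize 64 for the 4-bit weights and a larger block (256 here) for the requantized absmax; back-of-envelope numbers, not measurements:

    fp16   = 2.0                            # bytes per parameter
    nf4    = 0.5 + 4 / 64                   # 4-bit data + one fp32 absmax per 64 params
    nf4_dq = 0.5 + 1 / 64 + 4 / (64 * 256)  # absmax as uint8 + one fp32 scale per block
    print(fp16, nf4, nf4_dq)                # 2.0, 0.5625, ~0.516 (~0.37 bits/param saved)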
@@ -1816,7 +1817,7 @@ def test_bench_matmul(batch, seq, model, hidden):
 
     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
-    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
 
     # warmup
     for i in range(iters):
@@ -1848,11 +1849,18 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-        F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
     print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
+    torch.cuda.synchronize()
+    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+
 
     #torch.cuda.synchronize()
     #t0 = time.time()
     #for i in range(iters):
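All of these timing blocks follow the same warmup / synchronize / loop / synchronize pattern; a generalized helper, sketched under the assumption of a CUDA device (`fn` is any zero-argument callable):

    import time
    import torch

    def bench(fn, iters=100, warmup=10):
        for _ in range(warmup):
            fn()                      # warmup runs are excluded from timing
        torch.cuda.synchronize()      # drain queued kernels before starting the clock
        t0 = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()      # wait for the timed kernels to finish
        return time.time() - t0

Usage here would be, e.g., bench(lambda: bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)); the synchronize calls matter because CUDA launches are asynchronous and the wall clock would otherwise stop before the kernels finish.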
@@ -2351,11 +2359,12 @@ def test_normal_map_tree():
     print(pivots)
 
 
+@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemv_4bit(dtype, storage_type):
+def test_gemv_4bit(dtype, storage_type, double_quant):
     print('')
     for dim in [128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4*1024]:
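The stacked parametrize marks multiply out, so every dtype x storage_type combination now runs both with and without double quantization; the DQ_True/DQ_False ids make the new axis easy to select, e.g. pytest tests/test_functional.py -k "gemv_4bit and DQ_True" (invocation illustrative).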
@@ -2365,7 +2374,7 @@ def test_gemv_4bit(dtype, storage_type):
         max_err = 0
         max_relerr = 0
 
-        for i in range(1):
+        for i in range(100):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
@@ -2382,11 +2391,11 @@ def test_gemv_4bit(dtype, storage_type):
             #A.flatten()[:-1] = 0
             #B.flatten()[:-1] = 0
 
-            qB, state = F.quantize_4bit(B, quant_type=storage_type)
-            F.dequantize_4bit(qB, state)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            #F.dequantize_4bit(qB, state)
 
-            C2 = F.gemv_4bit(A, qB.t(), state=state)
             C3 = torch.matmul(A, B.t())
+            C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)
 
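Not shown in this hunk: the loop then accumulates the worst-case error of the two 4-bit paths against the fp16 reference. A hedged sketch of such a check (tolerances are illustrative, not the values the test asserts):

    # C3: fp16 reference, C2: fused gemv_4bit, C1: matmul_4bit (autograd path).
    err = (C3 - C2).abs().float()
    relerr = err / (C3.abs().float() + 1e-8)
    max_err = max(max_err, err.max().item())
    max_relerr = max(max_relerr, relerr.mean().item())
    torch.testing.assert_close(C1, C2, atol=1e-2, rtol=0.1)  # illustrative tolerances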