Added double quantization support and tests.

Tim Dettmers 2023-07-09 15:32:03 -07:00
parent 94168d79d7
commit 0f0390acb2
2 changed files with 38 additions and 18 deletions
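
In short: quantize_4bit can now compress its per-block absmax statistics ("double quantization", DQ), and gemv_4bit decompresses them on the fly, so DQ-quantized weights work with the fused 4-bit inference kernels. A minimal usage sketch based on the updated test code; the tensor shapes and dtypes here are illustrative assumptions, not part of the commit:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')         # single token, as required by gemv
B = torch.randn(4 * 4096, 4096, dtype=torch.float16, device='cuda')  # weight matrix

# 4-bit NF4 quantization with double quantization of the absmax statistics
qB, state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=True)

# low-level fused kernel path exercised by test_gemv_4bit
C_gemv = F.gemv_4bit(A, qB.t(), state=state)
# high-level entry point used in the benchmarks
C_matmul = bnb.matmul_4bit(A, qB.t(), quant_state=state)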


@@ -1461,16 +1461,25 @@ def gemv_4bit(
     transposed_B=False,
     state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
-        Bshape = B.shape
-        bout = Bshape[1]
-    else:
-        Bshape = state[1]
-        bout = Bshape[0]
+        raise ValueError('state cannot be None. gemv_4bit() requires the state from quantize_4bit()')
+    Bshape = state[1]
+    bout = Bshape[0]
+    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
+    if compressed_stats is not None:
+        offset, state2 = compressed_stats
+        absmax = dequantize_blockwise(absmax, state2)
+        absmax += offset
     if out is None:
         out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
     sA = A.shape
     sB = B.shape
     if transposed_A and len(sA) == 2:
@@ -1557,14 +1566,16 @@ def gemv_4bit(
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     post_call(prev_device)
     return out
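
For context on the code path above: with compress_statistics=True, state[0] holds a blockwise-quantized absmax and compressed_stats = (offset, state2) holds what is needed to restore it, so gemv_4bit now rebuilds the real absmax before the kernel call, which is why the kernel arguments switch from get_ptr(state[0]) to get_ptr(absmax). A standalone sketch of that round trip, assuming the usual mean-offset blockwise scheme; the variable names and the blocksize are illustrative, not the library's exact internal packing:

import torch
import bitsandbytes.functional as F

absmax = torch.rand(4096, device='cuda') * 3   # hypothetical per-block absmax statistics
offset = absmax.mean()                         # store the mean as an offset
qabsmax, state2 = F.quantize_blockwise(absmax - offset, blocksize=256)
compressed_stats = (offset, state2)

# recovery, mirroring the branch added to gemv_4bit above
offset, state2 = compressed_stats
absmax_restored = F.dequantize_blockwise(qabsmax, state2) + offset
print((absmax - absmax_restored).abs().max())  # small quantization error on the statistics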


@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)
-batch_size = 1
+batch_size = 5
 seqdim = 1
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))
@@ -1786,8 +1786,8 @@ values = []
 #values.append((batch_size, seqdim, 2560, 4*2560))
 #values.append((batch_size, seqdim, 4096, 4*4096))
 #values.append((batch_size, seqdim, 5120, 4*5120))
-#values.append((batch_size, seqdim, 6656, 4*6656))
-values.append((batch_size, seqdim, 8192, 4*8192))
+values.append((batch_size, seqdim, 6656, 4*6656))
+#values.append((batch_size, seqdim, 8192, 4*8192))
 #values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@@ -1804,6 +1804,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
     B_nf4, state_nf4 = F.quantize_nf4(B)
+    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
     linear8bit.eval()
@@ -1816,7 +1817,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
-    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     # warmup
     for i in range(iters):
@@ -1848,11 +1849,18 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-        F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
     print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
+    torch.cuda.synchronize()
+    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
     #torch.cuda.synchronize()
     #t0 = time.time()
     #for i in range(iters):
@@ -2351,11 +2359,12 @@ def test_normal_map_tree():
     print(pivots)
+@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemv_4bit(dtype, storage_type):
+def test_gemv_4bit(dtype, storage_type, double_quant):
     print('')
     for dim in [128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4*1024]:
@@ -2365,7 +2374,7 @@ def test_gemv_4bit(dtype, storage_type):
         max_err = 0
         max_relerr = 0
-        for i in range(1):
+        for i in range(100):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
@@ -2382,11 +2391,11 @@ def test_gemv_4bit(dtype, storage_type):
             #A.flatten()[:-1] = 0
             #B.flatten()[:-1] = 0
-            qB, state = F.quantize_4bit(B, quant_type=storage_type)
-            F.dequantize_4bit(qB, state)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            #F.dequantize_4bit(qB, state)
+            C2 = F.gemv_4bit(A, qB.t(), state=state)
             C3 = torch.matmul(A, B.t())
-            C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)