From 0f0390acb2a6307c6a92bbef2ff095bd7cbcdc90 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sun, 9 Jul 2023 15:32:03 -0700
Subject: [PATCH] Added double quantization support and tests.

---
 bitsandbytes/functional.py | 25 ++++++++++++++++++-------
 tests/test_functional.py   | 31 ++++++++++++++++++++-----------
 2 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 1f658ac..aa18925 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1461,16 +1461,25 @@ def gemv_4bit(
     transposed_B=False,
     state=None
 ):
+    prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
-        Bshape = B.shape
-        bout = Bshape[1]
-    else:
-        Bshape = state[1]
-        bout = Bshape[0]
+        raise ValueError('state cannot be None. gemv_4bit() requires the state from quantize_4bit()')
+
+    Bshape = state[1]
+    bout = Bshape[0]
+    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
+    if compressed_stats is not None:
+        offset, state2 = compressed_stats
+        absmax = dequantize_blockwise(absmax, state2)
+        absmax += offset
+
     if out is None:
         out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
+
+
+
 
     sA = A.shape
     sB = B.shape
     if transposed_A and len(sA) == 2:
@@ -1557,14 +1566,16 @@ def gemv_4bit(
 
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
 
+    post_call(prev_device)
+
     return out
 
 
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 68688ed..6dff784 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)
 
 
-batch_size = 1
+batch_size = 5
 seqdim = 1
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))
@@ -1786,8 +1786,8 @@ values = []
 #values.append((batch_size, seqdim, 2560, 4*2560))
 #values.append((batch_size, seqdim, 4096, 4*4096))
 #values.append((batch_size, seqdim, 5120, 4*5120))
-#values.append((batch_size, seqdim, 6656, 4*6656))
-values.append((batch_size, seqdim, 8192, 4*8192))
+values.append((batch_size, seqdim, 6656, 4*6656))
+#values.append((batch_size, seqdim, 8192, 4*8192))
 #values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@@ -1804,6 +1804,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
 
     B_nf4, state_nf4 = F.quantize_nf4(B)
+    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
 
     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
     linear8bit.eval()
@@ -1816,7 +1817,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
 
-    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
 
     # warmup
     for i in range(iters):
@@ -1848,11 +1849,18 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-        F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
     print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
 
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
+    torch.cuda.synchronize()
+    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+
+
     #torch.cuda.synchronize()
     #t0 = time.time()
     #for i in range(iters):
@@ -2351,11 +2359,12 @@ def test_normal_map_tree():
     print(pivots)
 
 
+@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemv_4bit(dtype, storage_type):
+def test_gemv_4bit(dtype, storage_type, double_quant):
     print('')
     for dim in [128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4*1024]:
@@ -2365,7 +2374,7 @@ def test_gemv_4bit(dtype, storage_type):
 
         max_err = 0
         max_relerr = 0
-        for i in range(1):
+        for i in range(100):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
@@ -2382,11 +2391,11 @@ def test_gemv_4bit(dtype, storage_type):
             #A.flatten()[:-1] = 0
             #B.flatten()[:-1] = 0
 
-            qB, state = F.quantize_4bit(B, quant_type=storage_type)
-            F.dequantize_4bit(qB, state)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            #F.dequantize_4bit(qB, state)
 
-            C2 = F.gemv_4bit(A, qB.t(), state=state)
             C3 = torch.matmul(A, B.t())
+            C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)
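
The hunks above wire double quantization into the 4-bit inference path: quantize_4bit(..., compress_statistics=True) stores the per-block absmax statistics in compressed form, and gemv_4bit now decompresses them (dequantize_blockwise plus the stored offset) before calling the naive inference kernel. A minimal usage sketch of that path, mirroring the new tests; the shapes and names are illustrative only and a CUDA device is assumed:

# Usage sketch (not part of the patch): exercises the double-quantization
# ("DQ") path added above, following the pattern in test_gemv_4bit.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')         # single-token activations
B = torch.randn(4 * 4096, 4096, dtype=torch.float16, device='cuda')  # weight matrix

# compress_statistics=True quantizes the absmax values blockwise as well;
# gemv_4bit decompresses them internally before the kernel call.
qB, state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=True)

C_gemv = F.gemv_4bit(A, qB.t(), state=state)            # fused 4-bit GEMV
C_mm   = bnb.matmul_4bit(A, qB.t(), quant_state=state)  # autograd-aware wrapper
C_ref  = torch.matmul(A, B.t())                         # fp16 reference
print((C_gemv - C_ref).abs().mean().item())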