From 94168d79d74174ee4ba7c183e2cfc7dacc89c939 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 9 Jul 2023 14:46:19 -0700 Subject: [PATCH] Added FP4 fast inference support. --- bitsandbytes/autograd/_functions.py | 4 ++-- bitsandbytes/functional.py | 3 +-- csrc/kernels.cu | 6 ++---- tests/test_functional.py | 17 +++++++++-------- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7848b7e..22f89b1 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -509,7 +509,7 @@ class MatMul4Bit(torch.autograd.Function): # 1. Dequantize # 2. MatmulnN - output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias) + output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias) # 3. Save state ctx.state = state @@ -540,7 +540,7 @@ class MatMul4Bit(torch.autograd.Function): # not supported by PyTorch. TODO: create work-around #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) - if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t()) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) return grad_A, grad_B, None, grad_bias, None diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index e09b267..1f658ac 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1459,8 +1459,7 @@ def gemv_4bit( out: Tensor = None, transposed_A=False, transposed_B=False, - state=None, - storage_type='nf4' + state=None ): #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4131477..1aaeb22 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3546,7 +3546,8 @@ template __global__ void kgemm_4bit_inference_naive(in T local_absmax = T(0.0f); for(int i = threadIdx.x; i < 16; i++) - quant_map[i] = nf4_data[i]; + quant_map[i] = datatype[i]; + __syncthreads(); // A: [1, K] @@ -3580,9 +3581,6 @@ template __global__ void kgemm_4bit_inference_naive(in { local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; - - //if(threadIdx.x == 0) - //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax); } if(inner_idx+num_values_4bit) diff --git a/tests/test_functional.py b/tests/test_functional.py index 54af27d..68688ed 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2351,12 +2351,13 @@ def test_normal_map_tree(): print(pivots) +@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16']) #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16']) -def test_gemv_4bit(dtype): +def test_gemv_4bit(dtype, storage_type): print('') - for dim in [64, 128, 256, 512, 1024, 2048, 4096]: + for dim in [128, 256, 512, 1024, 2048, 4096]: #for dim in [4*1024]: #for dim in [1*16]: errs = [] @@ -2364,7 +2365,7 @@ def test_gemv_4bit(dtype): max_err = 0 max_relerr = 0 - for i in range(100): + for i in range(1): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') #A = torch.rand(1, 4096, dtype=dtype, device='cuda') @@ -2381,8 +2382,8 @@ def test_gemv_4bit(dtype): #A.flatten()[:-1] = 0 #B.flatten()[:-1] = 0 - qB, state = F.quantize_nf4(B) - F.dequantize_nf4(qB, state) + qB, state = F.quantize_4bit(B, quant_type=storage_type) + F.dequantize_4bit(qB, state) C2 = F.gemv_4bit(A, qB.t(), state=state) C3 = torch.matmul(A, B.t()) @@ -2396,7 +2397,6 @@ def test_gemv_4bit(dtype): #print(A) #print(B) #print('='*89) - #print(C3.flatten()[-20:]) #print(C3) #print(C1.shape, C2.shape) @@ -2425,8 +2425,9 @@ def test_gemv_4bit(dtype): #print(dim, (max_err.item(), max_relerr.item())) print(C1.flatten()[-20:]) print(C2.flatten()[-20:]) - print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) - print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) + print(C3.flatten()[-20:]) + print(sum(errs)/len(errs)/math.sqrt(dim) , dim) + print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim) if dtype == torch.float16: assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5 assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005