From ba51d95d433ef2cd10e1e4bf3e325d5b50004ff9 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Tue, 11 Jul 2023 05:55:49 -0700
Subject: [PATCH] Added more extensive gemv tests; blocksize guard for gemv.

---
 bitsandbytes/autograd/_functions.py |   8 +-
 bitsandbytes/functional.py          |   1 +
 csrc/kernels.cu                     |  11 +-
 csrc/ops.cu                         |   1 +
 tests/test_functional.py            | 170 +++++++++++++++++-----------
 5 files changed, 122 insertions(+), 69 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 3b6016e..8a77e33 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -3,6 +3,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
 from typing import Tuple, Optional, List
+from warnings import warn
 
 import torch
 
@@ -565,6 +566,11 @@ def matmul(
 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        return F.gemv_4bit(A, B.t(), out, state=quant_state)
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
+        if A.shape[-1] % blocksize != 0:
+            warn(f"Some matrices' hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}")
+            return MatMul4Bit.apply(A, B, out, bias, quant_state)
+        else:
+            return F.gemv_4bit(A, B.t(), out, state=quant_state)
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 1972462..60a459c 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1504,6 +1504,7 @@ def gemv_4bit(
             lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
+
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
 
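Note (not part of the patch): the guard added in _functions.py only affects the batch-of-one inference fast path. A minimal sketch of the resulting dispatch, assuming the default 4-bit quantization blocksize of 64 and made-up layer shapes:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

hidden = 4096                                                   # multiple of the 64-element quantization blocks (assumption)
W = torch.randn(4 * hidden, hidden, dtype=torch.float16, device='cuda') / hidden**0.5
qW, state = F.quantize_4bit(W, quant_type='nf4')                # quant_state carries the blocksize

x = torch.randn(1, hidden, dtype=torch.float16, device='cuda')  # A.numel() == A.shape[-1], no grad
y = bnb.matmul_4bit(x, qW.t(), state)                           # -> F.gemv_4bit fast path

# If x.shape[-1] were not divisible by the blocksize, matmul_4bit would now warn
# and fall back to MatMul4Bit.apply instead of calling the gemv kernel.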
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 051af63..883864f 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -222,6 +222,7 @@ __device__ half dhDequantizeNF4(unsigned char val)
 __device__ float dDequantizeNF4(unsigned char val)
 {
+
   // the values for this tree was generated by test_normal_map_tree
   // in the file tests/test_functional.py
   if((val & 0b1000) == 8)
@@ -3526,10 +3527,9 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 {
 
   // per threadblock:
-  // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
-  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
+  // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
   // 4 warps -> 4 loads per iter
-  // 1x128 * 128x4 -> 1x4 outputs
+  // 1x32 * 32x4 -> 1x4 outputs per thread block
   typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
@@ -3547,7 +3547,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 
   for(int i = threadIdx.x; i < 16; i++)
     quant_map[i] = T(datatype[i]);
-
   __syncthreads();
 
   // A: [1, K]
@@ -3563,6 +3562,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
  {
    if((inner_idx_halved + num_values_8bit) < (K/2))
    {
+      // this load is the most important for performance
      reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
    }
    else
@@ -3597,6 +3597,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
    if(inner_idx+num_values_4bit < K)
    {
+      // this is also relatively important for performance
      if(BITS==16)
      {
        reinterpret_cast<int4*>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
@@ -3618,6 +3619,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
    }
    else
+      #pragma unroll
      for(int k = 0; k < num_values_4bit; k++)
        if(inner_idx + k < K)
          local_A[k] = A[inner_idx + k];
@@ -3625,6 +3627,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
          local_A[k] = T(0.0f);
 
+    // accumulate in float; small performance hit for Ampere, but lower error for outputs
    #pragma unroll
    for(int k = 0; k < num_values_4bit; k++)
    {
diff --git a/csrc/ops.cu b/csrc/ops.cu
index b524e0e..9776121 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -735,6 +735,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
 
  int num_blocks = (m+3)/4;
 
  kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+  CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
 template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
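For orientation (again, not part of the patch): the naive kernel computes a plain matrix-vector product against the blockwise-dequantized 4-bit weight, accumulating in float as the comment above notes. A rough PyTorch reference of that contract, assuming the same quantize/dequantize helpers the test below already uses:

import torch
import bitsandbytes.functional as F

def gemv_4bit_reference(A, qB, state):
    # dequantize the packed 4-bit weight back to the original [out_features, in_features] matrix
    B_deq = F.dequantize_4bit(qB, state)
    # accumulate in float32, then cast back to the activation dtype
    return (A.float() @ B_deq.t().float()).to(A.dtype)

# Usage sketch: C_ref = gemv_4bit_reference(A, qB, state) should closely match
# C2 = F.gemv_4bit(A, qB.t(), state=state), up to the tolerances asserted in the test below.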
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 9bcc3fa..ae495f3 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2262,7 +2262,7 @@ def test_fp4_quant(dtype):
     A2 = F.dequantize_fp4(qa, SA)
 
     err = (A1 - A2).abs().float()
-    relerr = (err/A1.abs().float()).mean()
+    relerr = (err/(A1.abs().float()+1e-8)).mean()
     idx = err > 1.0
     err = err.mean()
 
@@ -2361,91 +2361,133 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
+@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_gemv_4bit(dtype, storage_type, double_quant):
-    print('')
-    for dim in [128, 256, 512, 1024, 2048, 4096]:
+def test_gemv_4bit(dtype, storage_type, double_quant, kind):
+    for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
     #for dim in [4*1024]:
-    #for dim in [1*16]:
-        errs = []
-        relerrs = []
-        max_err = 0
-        max_relerr = 0
+    #for dim in [1*128]:
+        errs1 = []
+        errs2 = []
+        errs3 = []
+        relerrs1 = []
+        relerrs2 = []
+        relerrs3 = []
+        max_errs1 = []
+        max_errs2 = []
+        max_errs3 = []
+
         for i in range(100):
-            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, dim, dtype=dtype, device='cuda')
-            #B = torch.randn(4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-
-            #print('')
-            #print(A)
-            #print(B.t())
-            #A[:, :-1] = 0
-            #B[:, :-1] = 0
-            #A.flatten()[:-1] = 0
-            #B.flatten()[:-1] = 0
+            if kind == 'fc1':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'fc2':
+                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'attn':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'attn_packed':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
 
             qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
-            #F.dequantize_4bit(qB, state)
-
             C3 = torch.matmul(A, B.t())
             C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)
 
-            #print(state)
-            #print(qB)
+            err1 = (C1-C2).abs().float()
+            err2 = (C3-C2).abs().float()
+            err3 = (C3-C1).abs().float()
 
-            #print('')
-            #print(A)
-            #print(B)
-            #print('='*89)
-            #print(C3)
+            mag1 = torch.abs(C1).float()+1e-5
+            mag2 = torch.abs(C3).float()+1e-5
+            mag3 = torch.abs(C3).float()+1e-5
 
-            #print(C1.shape, C2.shape)
+            relerr1 = err1/mag1
+            relerr2 = err2/mag2
+            relerr3 = err3/mag3
 
-            # tensor cores are non-deterministic
-            # so we need to analyze errors around the mean
-            # to test our implementation
-            err = torch.abs(C1-C2).float()
-            mag = torch.abs(C1).float()+1e-5
-            relerr = err/mag
-            max_err = max(err.max(), max_err)
-            max_relerr = max(relerr.max(), max_relerr)
-            err = err.mean().item()
-            relerr = relerr.mean().item()
-            #print(err)
+            max_err1 = err1.max()
+            max_err2 = err2.max()
+            max_err3 = err3.max()
 
-            errs.append(err)
-            relerrs.append(relerr)
+            errs1.append(err1.mean().item())
+            errs2.append(err2.mean().item())
+            errs3.append(err3.mean().item())
+
+            relerrs1.append(relerr1.mean().item())
+            relerrs2.append(relerr2.mean().item())
+            relerrs3.append(relerr3.mean().item())
+
+            max_errs1.append(max_err1.item())
+            max_errs2.append(max_err2.item())
+            max_errs3.append(max_err3.item())
 
             c = int(C1.numel()*0.0014*(dim/256))+1
 
             c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
-        #print('')
-        #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
-        #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
-        #print(dim, (max_err.item(), max_relerr.item()))
-        print(C1.flatten()[-20:])
-        print(C2.flatten()[-20:])
-        #print(C1.flatten())
-        #print(C2.flatten())
-        #print(C3.flatten()[-20:])
-        print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
-        print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
+        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
+        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
+        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
+        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
+        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
+        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
+        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
+        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
+        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
+        absratio = err2/err3
+        relratio = relerr2/relerr3
+        maxratio = relerr2/relerr3
+
+        # for debugging if the tests fail
+        #
+        #print('='*80)
+        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
+        #print(C1.flatten()[-20:])
+        #print(C2.flatten()[-20:])
+        #print(f'inference vs training abs: {err1}')
+        #print(f'inference vs training rel: {relerr1}')
+        #print(f'inference vs training max: {maxerr1}')
+        #print(f'inference vs training vs torch err ratio abs: {absratio}')
+        #print(f'inference vs training vs torch err ratio rel: {relratio}')
+        #print(f'inference vs training vs torch err ratio max: {maxratio}')
 
         if dtype == torch.float16:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
+            if dim <= 512:
+                assert err1 < 7e-5
+                assert relerr1 < 0.0008
+            else:
+                assert err1 < 6e-5
+                assert relerr1 < 2e-4
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.005 and relratio > 0.995
+            assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.float32:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-8
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 1e-7
+            if dim <= 512:
+                assert err1 < 5e-8
+                assert relerr1 < 1e-6
+                assert maxerr1 < 1e-7
+            else:
+                assert err1 < 5e-8
+                assert relerr1 < 8e-6
+                assert maxerr1 < 1e-7
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.005 and relratio > 0.995
+            assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.bfloat16:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
+            if dim <= 512:
+                assert err1 < 5e-4
+                assert relerr1 < 0.007
+                assert maxerr1 < 0.015
+            else:
+                assert err1 < 2e-4
+                assert relerr1 < 0.002
+                assert maxerr1 < 0.0012
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.04 and relratio > 0.96
+            assert maxratio < 1.02 and maxratio > 0.98
 
 
 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
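Appendix (illustrative, not part of the commit): the new assertions all reduce to the same statistic, the mean absolute and relative error over 100 trials, normalized by sqrt(dim) so one threshold works across matrix sizes. A self-contained sketch of that statistic for any pair of outputs; the helper names here are hypothetical:

import math
import torch

def error_stats(C_ref, C_test, eps=1e-5):
    # per-trial absolute error, relative error (with a small eps for stability), and max error
    err = (C_ref - C_test).abs().float()
    relerr = err / (C_ref.abs().float() + eps)
    return err.mean().item(), relerr.mean().item(), err.max().item()

def normalized(values, dim):
    # average over trials, then divide by sqrt(dim), as test_gemv_4bit does before asserting
    return sum(values) / len(values) / math.sqrt(dim)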