From 90b0ac57b0d8d8f996126deb8bba6b7dc75b4327 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Tue, 11 Jul 2023 17:13:33 -0700
Subject: [PATCH] Fixed missing bias in bnb.matmul_4bit for inference; more
 tests.

---
 bitsandbytes/autograd/_functions.py |  5 ++++-
 bitsandbytes/functional.py          |  2 --
 tests/test_functional.py            | 30 ++++++++++++++++++++++++++++-
 tests/test_generation.py            | 21 ++++++++++++--------
 4 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 8a77e33..f2fdb7d 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -571,6 +571,9 @@ def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bia
             warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            return F.gemv_4bit(A, B.t(), out, state=quant_state)
+            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            if bias is not None:
+                out += bias
+            return out
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 60a459c..033ae32 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1512,8 +1512,6 @@ def gemv_4bit(
     return out
 
 
-
-
 def igemm(
     A: Tensor,
     B: Tensor,
diff --git a/tests/test_functional.py b/tests/test_functional.py
index ae495f3..3c891a3 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2364,7 +2364,7 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
 def test_gemv_4bit(dtype, storage_type, double_quant, kind):
-    for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
+    for dim in [128, 256, 512, 1024]:
     #for dim in [4*1024]:
     #for dim in [1*128]:
         errs1 = []
@@ -2525,3 +2525,31 @@ def test_managed():
     # assert (A==17).sum().item() == n*n
 
     # torch.testing.assert_close(A, torch.ones(A.shape)*289)
+
+
+@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
+@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
+def test_gemv_eye_4bit(storage_type, dtype, double_quant):
+    dims = 10
+    torch.random.manual_seed(np.random.randint(0, 412424242))
+    dims = torch.randint(0, 8192, size=(dims,)).tolist()
+    dims = [dim + (64-(dim % 64)) for dim in dims]
+    #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
+    for dim in dims:
+        A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda')
+        B = torch.eye(dim, dtype=dtype, device='cuda')
+
+        qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+        C3 = torch.matmul(A, B.t())
+        C2 = bnb.matmul_4bit(A, qB.t(), state)
+        A.requires_grad = True
+        C1 = bnb.matmul_4bit(A, qB.t(), state)
+
+        torch.testing.assert_close(A, C3)
+        torch.testing.assert_close(A, C1)
+        torch.testing.assert_close(A, C2)
+        #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
+        #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
+
+
diff --git a/tests/test_generation.py b/tests/test_generation.py
index 3159a69..b4c1a8c 100644
--- a/tests/test_generation.py
+++ b/tests/test_generation.py
@@ -65,7 +65,7 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
-dtypes = ['nf4', 'fp4', '16bit']
+dtypes = ['nf4', 'fp4']
 load_in_4bit = [True, False]
 values = list(product(models, dtypes))
 strfunc = lambda lst: [str(x) for x in lst]
@@ -73,14 +73,17 @@ ids = ['_'.join(strfunc(x)) for x in values]
 @pytest.fixture(scope='session', params=values, ids=ids)
 def model_and_tokenizer(request):
     model, tokenizer = get_model_and_tokenizer(request.param)
-    yield model, tokenizer
+    yield request.param, model, tokenizer
     del model
 
+@pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_pi(model_and_tokenizer, dtype, inference_kernel):
+#@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
+def test_pi(model_and_tokenizer, inference_kernel, DQ):
+    print('')
+    dtype = torch.float16
 
-    model, tokenizer = model_and_tokenizer
+    fixture_config, model, tokenizer = model_and_tokenizer
 
     generation_config = transformers.GenerationConfig(
         max_new_tokens=20,
@@ -94,16 +97,16 @@ def test_pi(model_and_tokenizer, dtype, inference_kernel):
     #text = 'Please write down the first 50 digits of pi.'
     #text = get_prompt_for_generation_eval(text)
     #text += ' Sure, here the first 50 digits of pi: 3.14159'
-    n_cases = 3
+    n_cases = 6
     text = '3.14159'
     if hasattr(model.config, 'quantization_config'):
         model.config.quantization_config.bnb_4bit_compute_dtype = dtype
+        model.config.quantization_config.bnb_4bit_use_double_quant = DQ
 
     if not inference_kernel:
         text = [text]*n_cases
     inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
     x = inputs['input_ids']
-    failure_count = 0
     outputs = []
     if inference_kernel:
         for i in range(n_cases):
@@ -116,10 +119,12 @@
 
     assert len(outputs) == n_cases
 
+    failure_count = 0
     for i in range(n_cases):
         if not outputs[i][:len(str(math.pi))] == str(math.pi):
             failure_count += 1
-    if failure_count > 1:
+    failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4)
+    if failure_count > failure_max:
         print(math.pi)
         for out in outputs:
             print(out)
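
Note (not part of the patch): the fix above means the fast single-token inference path now applies `bias` after `F.gemv_4bit` instead of silently dropping it. Below is a minimal sketch of how that path is exercised, mirroring the API usage in the tests above; the tensor sizes, the weight/bias values, and the `bias=` call are illustrative assumptions, not code from the patch.

# Sketch: exercising the fixed bias handling in bnb.matmul_4bit (assumes a CUDA device).
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 4096                                                        # hidden size, multiple of the 4-bit blocksize (64)
A = torch.randn(1, 1, dim, dtype=torch.float16, device='cuda')    # single-token activation -> gemv inference path
B = torch.randn(dim, dim, dtype=torch.float16, device='cuda')     # weight matrix
bias = torch.randn(dim, dtype=torch.float16, device='cuda')

# Quantize the weight to NF4, as in test_gemv_eye_4bit above.
qB, state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=False)

# With this patch, the gemv_4bit output has the bias added before returning.
out = bnb.matmul_4bit(A, qB.t(), state, bias=bias)

# Full-precision reference for comparison (some quantization error is expected).
ref = torch.matmul(A, B.t()) + bias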