From c0c352b3791a5aab14263108595479b9db58fa1f Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sun, 5 Feb 2023 06:29:52 -0800
Subject: [PATCH] Added bias test for LinearFP4 and basic test.

---
 bitsandbytes/nn/__init__.py |  2 +-
 bitsandbytes/nn/modules.py  |  6 +++---
 tests/test_modules.py       | 43 +++++++++++--------------------------
 3 files changed, 16 insertions(+), 35 deletions(-)

diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index edc595a..79fb51e 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 6dfb06c..4c719c6 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -188,9 +188,9 @@ class LinearFP4(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
 
-        if getattr(self.weight, 'state', None) is None:
-            print('FP4 state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
-        out = bnb.matmul_fp(x, self.weight, bias=self.bias, state=self.weight.state)
+        if getattr(self.weight, 'quant_state', None) is None:
+            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
 
         return out
 
diff --git a/tests/test_modules.py b/tests/test_modules.py
index d78f0c9..ba67bfc 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -330,12 +330,8 @@ def test_linear8bitlt_inference(threshold):
 
 
 def test_linear8bitlt_accumulated_gradient():
-    l1 = torch.nn.Sequential(
-        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
-    )
-    l2 = torch.nn.Sequential(
-        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
-    )
+    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
     l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
     l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
     l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@@ -376,21 +372,10 @@ def test_linear8bitlt_accumulated_gradient():
             torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
 
 
-threshold = [0.0, 2.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-
-
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = (
-        bnb.nn.Linear8bitLt(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .cuda()
-        .half()
-    )
+    l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
     assert l1.weight.dtype == torch.int8
 
     l1.eval()
@@ -446,13 +431,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         assert mlp.fc1.weight.dtype == torch.int8
         assert mlp.fc2.weight.dtype == torch.int8
 
-    mlp = (
-        MLP8bit(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .half()
-        .to("cuda")
-    )
+    mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -504,10 +483,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         assert (idx == 0).sum().item() <= b1.numel() * 0.005
 
 
-def test_linear8bitlt_fp32_bias():
+@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32
 
     for i in range(100):
@@ -517,11 +497,12 @@ def test_linear8bitlt_fp32_bias():
         assert l1.bias.dtype == torch.float16
 
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64, bias=False).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         assert l1.bias is None
+
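
Usage note (not part of the patch): the snippet below is a minimal standalone sketch of the behavior the new test_linear_kbit_fp32_bias test checks for the FP4 path, assuming a CUDA-enabled bitsandbytes build. The weight is quantized when the module is moved to the GPU, and forward() casts the fp32 bias to the input dtype, as the modules.py change above implies.

import torch
import bitsandbytes as bnb

# Quantization happens when the layer is moved to the GPU, not at construction.
layer = bnb.nn.LinearFP4(32, 64).cuda()
assert layer.weight.dtype in [torch.int8, torch.uint8]  # weight is stored quantized
assert layer.bias.dtype == torch.float32                # bias is not quantized

# The first forward pass with a half-precision input casts the bias to x.dtype.
x = torch.randn(16, 8, 32, device="cuda").half()
out = layer(x)
assert layer.bias.dtype == torch.float16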