From 7f0773aede92a8be5bf0645185de4f5707b3a2a8 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sun, 5 Feb 2023 06:49:54 -0800
Subject: [PATCH] Added backprop test for Linear8bitLt and LinearFP4.

---
 tests/test_modules.py | 40 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/tests/test_modules.py b/tests/test_modules.py
index ba67bfc..41cc050 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -375,7 +375,7 @@ def test_linear8bitlt_accumulated_gradient():
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
+    l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
     assert l1.weight.dtype == torch.int8
 
     l1.eval()
@@ -506,3 +506,41 @@ def test_linear_kbit_fp32_bias(module):
         o1 = l1(b1)
         assert l1.bias is None
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_kbit_backprop(module):
+    b = 17
+    dim1 = 37
+    dim2 = 83
+
+    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
+    ref[1].weight.requires_grad = False
+    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
+    kbit[0].weight.detach().copy_(ref[0].weight)
+    kbit[1].weight.detach().copy_(ref[1].weight)
+    kbit[0].bias.detach().copy_(ref[0].bias)
+    kbit[1].bias.detach().copy_(ref[1].bias)
+    ref = ref.half().cuda()
+    kbit = kbit.half().cuda()
+
+    for i in range(100):
+        batch = torch.randn(b, dim1).half().cuda()
+        out1 = ref(batch)
+        out2 = kbit(batch)
+        out1.mean().backward()
+        out2.mean().backward()
+
+        grad1 = ref[0].weight.grad
+        grad2 = kbit[0].weight.grad
+        bgrad1 = ref[0].bias.grad
+        bgrad2 = kbit[0].bias.grad
+
+        torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
+        torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
+        ref.zero_grad()
+        kbit.zero_grad()
+
+        assert kbit[0].weight.grad.sum().item() == 0
+        assert kbit[0].bias.grad.sum().item() == 0
+
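
Note: the equivalence the new test asserts can be reproduced outside the
pytest harness. The listing below is a minimal standalone sketch, assuming a
CUDA-capable GPU and a bitsandbytes build that provides bnb.nn.Linear8bitLt;
it uses the same shapes and tolerances as test_kbit_backprop, with
torch.testing.assert_close standing in for the deprecated assert_allclose
used in the patch.

    import torch
    import torch.nn as nn
    import bitsandbytes as bnb

    torch.manual_seed(0)
    b, dim1, dim2 = 17, 37, 83

    # FP16 reference stack vs. a stack whose second layer is quantized.
    ref = nn.Sequential(nn.Linear(dim1, dim2), nn.Linear(dim2, 10))
    kbit = nn.Sequential(nn.Linear(dim1, dim2), bnb.nn.Linear8bitLt(dim2, 10))
    for i in (0, 1):
        kbit[i].weight.detach().copy_(ref[i].weight)
        kbit[i].bias.detach().copy_(ref[i].bias)
    ref, kbit = ref.half().cuda(), kbit.half().cuda()

    batch = torch.randn(b, dim1).half().cuda()
    ref(batch).mean().backward()
    kbit(batch).mean().backward()

    # Gradients that flow through the quantized layer into the first
    # (unquantized) layer should closely match the FP16 reference.
    torch.testing.assert_close(kbit[0].weight.grad, ref[0].weight.grad, atol=0.008, rtol=0.05)
    torch.testing.assert_close(kbit[0].bias.grad, ref[0].bias.grad, atol=0.008, rtol=0.05)

The test itself runs with: pytest tests/test_modules.py -k test_kbit_backprop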