Added backprop test for Linear8bitLt and LinearFP4.

Tim Dettmers 2023-02-05 06:49:54 -08:00
parent c0c352b379
commit 7f0773aede


@@ -375,7 +375,7 @@ def test_linear8bitlt_accumulated_gradient():
@pytest.mark.parametrize("threshold", [0.0, 2.0])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
assert l1.weight.dtype == torch.int8
l1.eval()
@@ -506,3 +506,41 @@ def test_linear_kbit_fp32_bias(module):
     o1 = l1(b1)
     assert l1.bias is None
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_kbit_backprop(module):
+    b = 17
+    dim1 = 37
+    dim2 = 83
+
+    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
+    ref[1].weight.requires_grad = False
+    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
+    kbit[0].weight.detach().copy_(ref[0].weight)
+    kbit[1].weight.detach().copy_(ref[1].weight)
+    kbit[0].bias.detach().copy_(ref[0].bias)
+    kbit[1].bias.detach().copy_(ref[1].bias)
+    ref = ref.half().cuda()
+    kbit = kbit.half().cuda()
+
+    for i in range(100):
+        batch = torch.randn(b, dim1).half().cuda()
+        out1 = ref(batch)
+        out2 = kbit(batch)
+        out1.mean().backward()
+        out2.mean().backward()
+
+        grad1 = ref[0].weight.grad
+        grad2 = kbit[0].weight.grad
+        bgrad1 = ref[0].bias.grad
+        bgrad2 = kbit[0].bias.grad
+
+        torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
+        torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
+        ref.zero_grad()
+        kbit.zero_grad()
+
+        assert kbit[0].weight.grad.sum().item() == 0
+        assert kbit[0].bias.grad.sum().item() == 0
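
For reference, the pattern the new test exercises (an fp16 torch.nn.Linear feeding a bitsandbytes k-bit layer, with gradients compared against an all-fp16 reference) reduces to the sketch below. This is a minimal illustration, not part of the commit; it assumes a CUDA device and the constructor signatures visible in the diff above.

import torch
import bitsandbytes as bnb

# An fp16 layer feeding a quantized layer, mirroring `kbit` in the test.
# Swap bnb.nn.Linear8bitLt for bnb.nn.LinearFP4 to cover the FP4 case.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 32),
    bnb.nn.Linear8bitLt(32, 10),
).half().cuda()

x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
model(x).mean().backward()

# Gradients must flow *through* the quantized layer back to the fp16 layer;
# this is what the assert_allclose checks on kbit[0] in the test verify.
assert model[0].weight.grad is not None
assert model[0].bias.grad is not None

Note that the test freezes ref[1].weight and compares gradients at index 0, so it validates the backward pass through Linear8bitLt/LinearFP4 rather than the quantized weights' own gradients.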