Added backprop test for Linear8bitLt and LinearFP4.
This commit is contained in:
parent
c0c352b379
commit
7f0773aede
|
@ -375,7 +375,7 @@ def test_linear8bitlt_accumulated_gradient():
|
|||
@pytest.mark.parametrize("threshold", [0.0, 2.0])
|
||||
@pytest.mark.parametrize("memory_efficient_backward", [False])
|
||||
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
|
||||
l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
|
||||
l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
|
||||
assert l1.weight.dtype == torch.int8
|
||||
|
||||
l1.eval()
|
||||
|
@ -506,3 +506,41 @@ def test_linear_kbit_fp32_bias(module):
|
|||
o1 = l1(b1)
|
||||
assert l1.bias is None
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
|
||||
def test_kbit_backprop(module):
|
||||
b = 17
|
||||
dim1 = 37
|
||||
dim2 = 83
|
||||
|
||||
ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
|
||||
ref[1].weight.requires_grad = False
|
||||
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
|
||||
kbit[0].weight.detach().copy_(ref[0].weight)
|
||||
kbit[1].weight.detach().copy_(ref[1].weight)
|
||||
kbit[0].bias.detach().copy_(ref[0].bias)
|
||||
kbit[1].bias.detach().copy_(ref[1].bias)
|
||||
ref = ref.half().cuda()
|
||||
kbit = kbit.half().cuda()
|
||||
|
||||
for i in range(100):
|
||||
batch = torch.randn(b, dim1).half().cuda()
|
||||
out1 = ref(batch)
|
||||
out2 = kbit(batch)
|
||||
out1.mean().backward()
|
||||
out2.mean().backward()
|
||||
|
||||
grad1 = ref[0].weight.grad
|
||||
grad2 = kbit[0].weight.grad
|
||||
bgrad1 = ref[0].bias.grad
|
||||
bgrad2 = kbit[0].bias.grad
|
||||
|
||||
torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
|
||||
torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
|
||||
ref.zero_grad()
|
||||
kbit.zero_grad()
|
||||
|
||||
assert kbit[0].weight.grad.sum().item() == 0
|
||||
assert kbit[0].bias.grad.sum().item() == 0
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user