diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index acd90f5..63b7156 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -456,7 +456,6 @@ class MatMul8bitLt(torch.autograd.Function):
 
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
-            #grad_B = torch.matmul(grad_output.t(), A)
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 1319cf7..d0a9051 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -332,12 +332,13 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
     l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
     l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
-    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
-    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
-    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
-    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
-    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
-    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+    l1[0].weight.data.copy_(l2[0].weight.data)
+    l1[1].weight.data.copy_(l2[1].weight.data)
+    l1[0].bias.data.copy_(l2[0].bias.data)
+    l1[1].bias.data.copy_(l2[1].bias.data)
+
+    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
+    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
 
     acc_steps = 10
 
@@ -353,7 +354,6 @@ def test_linear8bitlt_accumulated_gradient():
             assert l1[0].state.CxB is not None
             assert l1[1].state.CxB is not None
 
-        print(i)
         if i > 0 and i % acc_steps == 0:
             opt1.step()
             opt1.zero_grad(True)
@@ -368,9 +368,11 @@ def test_linear8bitlt_accumulated_gradient():
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
+            l1[0].bias.data.copy_(l2[0].bias.data)
+            l1[1].bias.data.copy_(l2[1].bias.data)
         else:
-            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
-            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)
+            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
+            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
 
 
 @pytest.mark.parametrize("threshold", [0.0, 2.0])