Fixed gradient accumulation test.

This commit is contained in:
Tim Dettmers 2023-05-07 15:06:17 -07:00
parent 675baa79d2
commit 4bd1151829
2 changed files with 11 additions and 10 deletions

View File

@ -456,7 +456,6 @@ class MatMul8bitLt(torch.autograd.Function):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
#grad_B = torch.matmul(grad_output.t(), A)
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)

View File

@ -332,12 +332,13 @@ def test_linear8bitlt_inference(threshold):
def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
l1[0].bias.data.copy_(l2[0].bias.data)
l1[1].bias.data.copy_(l2[1].bias.data)
opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
acc_steps = 10
@ -353,7 +354,6 @@ def test_linear8bitlt_accumulated_gradient():
assert l1[0].state.CxB is not None
assert l1[1].state.CxB is not None
print(i)
if i > 0 and i % acc_steps == 0:
opt1.step()
opt1.zero_grad(True)
@ -368,9 +368,11 @@ def test_linear8bitlt_accumulated_gradient():
# we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
l1[0].bias.data.copy_(l2[0].bias.data)
l1[1].bias.data.copy_(l2[1].bias.data)
else:
torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)
torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
@pytest.mark.parametrize("threshold", [0.0, 2.0])