Fixed gradient accumulation test.
This commit is contained in:
parent
675baa79d2
commit
4bd1151829
|
@ -456,7 +456,6 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
|
||||
if req_gradB:
|
||||
#grad_B = torch.matmul(grad_output.t(), A)
|
||||
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
|
||||
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
|
||||
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
|
||||
|
|
|
@ -332,12 +332,13 @@ def test_linear8bitlt_inference(threshold):
|
|||
def test_linear8bitlt_accumulated_gradient():
|
||||
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
|
||||
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
|
||||
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
|
||||
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
|
||||
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
|
||||
l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
|
||||
opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
|
||||
opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
|
||||
l1[0].weight.data.copy_(l2[0].weight.data)
|
||||
l1[1].weight.data.copy_(l2[1].weight.data)
|
||||
l1[0].bias.data.copy_(l2[0].bias.data)
|
||||
l1[1].bias.data.copy_(l2[1].bias.data)
|
||||
|
||||
opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
|
||||
opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
|
||||
|
||||
acc_steps = 10
|
||||
|
||||
|
@ -353,7 +354,6 @@ def test_linear8bitlt_accumulated_gradient():
|
|||
assert l1[0].state.CxB is not None
|
||||
assert l1[1].state.CxB is not None
|
||||
|
||||
print(i)
|
||||
if i > 0 and i % acc_steps == 0:
|
||||
opt1.step()
|
||||
opt1.zero_grad(True)
|
||||
|
@ -368,9 +368,11 @@ def test_linear8bitlt_accumulated_gradient():
|
|||
# we do this copy because otherwise we have small divergences over time that add up
|
||||
l1[0].weight.data.copy_(l2[0].weight.data)
|
||||
l1[1].weight.data.copy_(l2[1].weight.data)
|
||||
l1[0].bias.data.copy_(l2[0].bias.data)
|
||||
l1[1].bias.data.copy_(l2[1].bias.data)
|
||||
else:
|
||||
torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
|
||||
torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)
|
||||
torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("threshold", [0.0, 2.0])
|
||||
|
|
Loading…
Reference in New Issue
Block a user