Fixed gradient accumulation test.
parent 675baa79d2
commit 4bd1151829
@@ -456,7 +456,6 @@ class MatMul8bitLt(torch.autograd.Function):
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
 
         if req_gradB:
-            #grad_B = torch.matmul(grad_output.t(), A)
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
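The only change in this hunk is the removal of a stale comment. As a hedged, plain-PyTorch illustration of the identity that comment referred to (the weight gradient of a linear layer is grad_output.t() @ A), and not the quantized igemmlt path the surrounding code actually uses, here is a minimal sketch; the names x, W, and grad_y are hypothetical:

# Sketch: for y = x @ W.t(), autograd's weight gradient equals grad_y.t() @ x.
import torch

x = torch.randn(4, 8)
W = torch.randn(16, 8, requires_grad=True)

y = x @ W.t()
grad_y = torch.randn_like(y)
y.backward(grad_y)

# autograd's W.grad matches the manual formula the removed comment alluded to
assert torch.allclose(W.grad, grad_y.t() @ x, atol=1e-5)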
@@ -332,12 +332,13 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
     l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
     l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
-    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
-    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
-    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
-    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
-    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
-    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+    l1[0].weight.data.copy_(l2[0].weight.data)
+    l1[1].weight.data.copy_(l2[1].weight.data)
+    l1[0].bias.data.copy_(l2[0].bias.data)
+    l1[1].bias.data.copy_(l2[1].bias.data)
+
+    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
+    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
 
     acc_steps = 10
 
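The new setup copies the reference layer's values into the 8-bit layers' existing parameters instead of replacing the reference layers' parameters with fresh clones, and the comparison optimizers become 32-bit Adam. A minimal, hedged sketch of the mechanical difference, using only plain torch.nn.Linear modules (the names lin, ref, and p_before are made up for illustration):

# In-place .data.copy_ changes values but keeps the existing Parameter object,
# which matters when that object is special (e.g. an Int8Params subclass) or is
# already registered with an optimizer. Assigning a new torch.nn.Parameter rebinds
# the attribute to a brand-new object.
import torch

lin = torch.nn.Linear(4, 4)
ref = torch.nn.Linear(4, 4)

p_before = lin.weight
lin.weight.data.copy_(ref.weight.data)               # values updated, identity kept
assert lin.weight is p_before
assert torch.equal(lin.weight.data, ref.weight.data)

lin.weight = torch.nn.Parameter(ref.weight.clone())  # old style: new Parameter object
assert lin.weight is not p_before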
@@ -353,7 +354,6 @@ def test_linear8bitlt_accumulated_gradient():
         assert l1[0].state.CxB is not None
         assert l1[1].state.CxB is not None
 
-        print(i)
         if i > 0 and i % acc_steps == 0:
             opt1.step()
             opt1.zero_grad(True)
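For context, this is the accumulation pattern the test exercises: gradients keep adding up in each parameter's .grad across iterations, and the optimizer steps only once every acc_steps. A generic, hedged sketch with a hypothetical model and random data (not the test code itself):

import torch

model = torch.nn.Linear(32, 32)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
acc_steps = 10

for i in range(50):
    x = torch.randn(8, 32)
    loss = model(x).pow(2).mean()
    loss.backward()                      # accumulates into p.grad
    if i > 0 and i % acc_steps == 0:
        opt.step()                       # apply the summed gradient
        opt.zero_grad(set_to_none=True)  # start the next accumulation window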
@@ -368,9 +368,11 @@ def test_linear8bitlt_accumulated_gradient():
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
+            l1[0].bias.data.copy_(l2[0].bias.data)
+            l1[1].bias.data.copy_(l2[1].bias.data)
         else:
-            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
-            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)
+            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
+            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
 
 
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
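The loosened comparison gives the int8 path explicit room against the fp16 reference instead of relying on default tolerances. A hedged illustration of how atol/rtol in torch.testing.assert_close absorb small numerical differences; the tensors a and b are synthetic and unrelated to the test:

import torch

a = torch.randn(32, 32, dtype=torch.float16)
b = (a.float() + 2e-4).half()   # tiny synthetic perturbation

# Passes with the loosened tolerances; the stricter defaults for fp16 may not.
torch.testing.assert_close(a, b, atol=1e-3, rtol=1e-3)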