add memory efficient backward

This commit is contained in:
justheuristic 2022-09-18 00:52:53 +03:00
parent 579b8c782f
commit 591f60395a
2 changed files with 17 additions and 8 deletions

View File

@ -381,7 +381,6 @@ class MatMul8bitLt(torch.autograd.Function):
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
raise NotImplementedError("WIP")
CB = state.CB.to(ctx.dtype_B)
CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)

View File

@ -14,13 +14,15 @@ class MockArgs(object):
class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
)
self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
)
def forward(self, x):
@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
@pytest.mark.parametrize("memory_efficient_backward", [True, False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (
bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
bnb.nn.Linear8bitLt(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.cuda()
.half()
)
@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8
mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.half()
.to("cuda")
)
@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.device.type == "cuda"
mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.to(torch.float16)
.to("cuda")
)
@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.device.type == "cuda"
def test_linear8bitlt_fp32_bias():
# casts model to fp16 -> int8 automatically
l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()