add memory efficient backward

This commit is contained in:
justheuristic 2022-09-18 00:52:53 +03:00
parent 579b8c782f
commit 591f60395a
2 changed files with 17 additions and 8 deletions

View File

@ -381,7 +381,6 @@ class MatMul8bitLt(torch.autograd.Function):
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None: elif state.CB is not None:
raise NotImplementedError("WIP")
CB = state.CB.to(ctx.dtype_B) CB = state.CB.to(ctx.dtype_B)
CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype)) CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)

View File

@ -14,13 +14,15 @@ class MockArgs(object):
class MLP8bit(torch.nn.Module): class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super(MLP8bit, self).__init__() super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt( self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
) )
self.fc2 = bnb.nn.Linear8bitLt( self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
) )
def forward(self, x): def forward(self, x):
@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names) @pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold): @pytest.mark.parametrize("memory_efficient_backward", [True, False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = ( l1 = (
bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False) bnb.nn.Linear8bitLt(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.cuda() .cuda()
.half() .half()
) )
@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
mlp = ( mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.half() .half()
.to("cuda") .to("cuda")
) )
@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.device.type == "cuda" assert mlp.fc2.weight.device.type == "cuda"
mlp = ( mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
.to(torch.float16) .to(torch.float16)
.to("cuda") .to("cuda")
) )
@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.device.type == "cuda" assert mlp.fc2.weight.device.type == "cuda"
def test_linear8bitlt_fp32_bias(): def test_linear8bitlt_fp32_bias():
# casts model to fp16 -> int8 automatically # casts model to fp16 -> int8 automatically
l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda() l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()