add memory efficient backward
This commit is contained in:
parent
579b8c782f
commit
591f60395a
|
@ -381,7 +381,6 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
|
||||
elif state.CB is not None:
|
||||
raise NotImplementedError("WIP")
|
||||
CB = state.CB.to(ctx.dtype_B)
|
||||
CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype))
|
||||
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
|
|
|
@ -14,13 +14,15 @@ class MockArgs(object):
|
|||
|
||||
|
||||
class MLP8bit(torch.nn.Module):
|
||||
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
|
||||
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
|
||||
super(MLP8bit, self).__init__()
|
||||
self.fc1 = bnb.nn.Linear8bitLt(
|
||||
dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
|
||||
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
|
||||
threshold=threshold
|
||||
)
|
||||
self.fc2 = bnb.nn.Linear8bitLt(
|
||||
dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
|
||||
dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
|
||||
threshold=threshold
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
|
|||
|
||||
|
||||
@pytest.mark.parametrize("threshold", values, ids=names)
|
||||
def test_linear8bitlt_no_fp16_weights(threshold):
|
||||
@pytest.mark.parametrize("memory_efficient_backward", [True, False])
|
||||
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
|
||||
l1 = (
|
||||
bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
|
||||
bnb.nn.Linear8bitLt(
|
||||
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
|
||||
)
|
||||
.cuda()
|
||||
.half()
|
||||
)
|
||||
|
@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
|
|||
assert mlp.fc2.weight.dtype == torch.int8
|
||||
|
||||
mlp = (
|
||||
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
|
||||
MLP8bit(
|
||||
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
|
||||
)
|
||||
.half()
|
||||
.to("cuda")
|
||||
)
|
||||
|
@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
|
|||
assert mlp.fc2.weight.device.type == "cuda"
|
||||
|
||||
mlp = (
|
||||
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
|
||||
MLP8bit(
|
||||
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
|
||||
)
|
||||
.to(torch.float16)
|
||||
.to("cuda")
|
||||
)
|
||||
|
@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold):
|
|||
assert mlp.fc2.weight.device.type == "cuda"
|
||||
|
||||
|
||||
|
||||
def test_linear8bitlt_fp32_bias():
|
||||
# casts model to fp16 -> int8 automatically
|
||||
l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
|
||||
|
|
Loading…
Reference in New Issue
Block a user