add memory efficient backward
parent 579b8c782f
commit 591f60395a
bitsandbytes/autograd/_functions.py
@@ -381,7 +381,6 @@ class MatMul8bitLt(torch.autograd.Function):
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

             elif state.CB is not None:
-                raise NotImplementedError("WIP")
                 CB = state.CB.to(ctx.dtype_B)
                 CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
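The `elif state.CB is not None:` branch above no longer raises: when only the int8 weight `CB` and its row-wise scales `SCB` are kept around, the activation gradient is computed by dequantizing the weight on the fly (`W ≈ CB * SCB / 127`) and running a plain matmul, so no fp16 copy of the weight has to be held in memory. A minimal sketch of that computation, using the tensor names from the diff; the standalone function is illustrative, not part of the library API:

```python
import torch

def grad_a_from_int8_weight(grad_output: torch.Tensor,
                            CB: torch.Tensor,
                            SCB: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch of the CB-based backward branch above.

    CB:  int8 weight, shape [out_features, in_features]
    SCB: per-row absmax scales from quantization, shape [out_features]
    """
    # Row-wise absmax quantization stores CB ~= round(127 * W / absmax_row),
    # so the weight is recovered approximately as W ~= CB * SCB / 127.
    W = CB.to(grad_output.dtype) * SCB.unsqueeze(1).to(grad_output.dtype) / 127.0
    # [..., out_features] @ [out_features, in_features] -> [..., in_features]
    return torch.matmul(grad_output, W)
```

The trade-off is an extra dequantization in the backward pass instead of keeping the full fp16 weight alive between forward and backward, which is where the memory saving comes from.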
tests/test_modules.py
@@ -14,13 +14,15 @@ class MockArgs(object):


 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )

     def forward(self, x):
@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]


 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )
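The test is now parametrized over the new flag as well, so every threshold case runs both with and without it. A hedged usage sketch of one such case outside pytest (the shapes, the threshold value, and the loss are illustrative; the constructor arguments are the ones this hunk passes):

```python
import torch
import bitsandbytes as bnb

# Build a single int8 layer the same way `l1` is built above, with the new flag on.
l1 = (
    bnb.nn.Linear8bitLt(
        32, 64, threshold=6.0, has_fp16_weights=False, memory_efficient_backward=True
    )
    .cuda()
    .half()
)

x = torch.randn(8, 32, device="cuda", dtype=torch.float16, requires_grad=True)
out = l1(x)
assert out.dtype == torch.float16
out.float().sum().backward()   # exercises the backward path selected by the flag
assert x.grad is not None and x.grad.shape == x.shape
```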
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.device.type == "cuda"

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .to(torch.float16)
         .to("cuda")
     )
@@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.device.type == "cuda"


+
 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
     l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
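Taken together, the commit threads `memory_efficient_backward` from the test helper down to the CB-based gradient in `MatMul8bitLt.backward`. A closing end-to-end sketch of the scenario the parametrized test targets, assuming a CUDA device; the `Sequential` wrapper, shapes, and threshold value are illustrative stand-ins for the `MLP8bit` helper, not the test body:

```python
import torch
import bitsandbytes as bnb

# Two chained int8 layers, mirroring the MLP8bit helper from the diff;
# the .half().to("cuda") cast order matches one of the variants the test covers.
mlp = (
    torch.nn.Sequential(
        bnb.nn.Linear8bitLt(
            32, 64, threshold=6.0, has_fp16_weights=False, memory_efficient_backward=True
        ),
        bnb.nn.Linear8bitLt(
            64, 32, threshold=6.0, has_fp16_weights=False, memory_efficient_backward=True
        ),
    )
    .half()
    .to("cuda")
)
assert mlp[0].weight.dtype == torch.int8 and mlp[1].weight.dtype == torch.int8

x = torch.randn(16, 8, 32, device="cuda", dtype=torch.float16, requires_grad=True)
out = mlp(x)
out.float().sum().backward()   # grad_A for each layer comes through the CB branch
assert x.grad is not None and x.grad.shape == x.shape
```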