add memory efficient backward

This commit is contained in:
dbaranchuk 2022-08-23 23:39:54 +03:00
parent 9d60b3c527
commit 8ae9bb23ad
2 changed files with 28 additions and 24 deletions

View File

@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function):
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
if state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
elif state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0)
@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
assert (
state.has_fp16_weights
), "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3:
grad_output = grad_output.view(
grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function):
if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
if state.CxBt is None and state.has_fp16_weights:
CBt = state.CBt
elif state.CxBt is None:
assert state.CBt is None
CB = state.CB.half()
SCB = state.SCB.unsquezee(1).half()
SCBt = state.SCBt.unsquezee(1).half()
Bt = (CB * SCB).t().contiguous()
CBt = (Bt / SCBt).t().to(torch.int8)
CxBt, SBt = F.transform(
CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
if req_gradBias:

View File

@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter):
has_fp16_weights=False,
CB=None,
SCB=None,
SCBt=None,
):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
cls.SCBt = None
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter):
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.data = CB
setattr(self, "CB", CB)
setattr(self, "SCB", SCB)
setattr(self, "SCBt", SCBt)
return self
@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter):
)
new_param.CB = self.CB
new_param.SCB = self.SCB
new_param.SCB = self.SCBt
return new_param
@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear):
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
self.state.SCBt = self.weight.SCBt
self.weight.CB = None
self.weight.SCB = None
self.weight.SCBt = None
def forward(self, x):
self.state.is_training = self.training
@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights and self.state.CB is not None:
# if not self.state.has_fp16_weights and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
# del self.state.CB
# self.weight.data = self.state.CxB
return out