forked from mrq/bitsandbytes-rocm

add memory efficient backward

parent 9d60b3c527
commit 8ae9bb23ad
@@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            else:
-                if state.CxB is None:
-                    # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
-                    # we also need to convert it to the turing/ampere format
-                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            elif state.CxB is None:
+                # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
+                # we also need to convert it to the turing/ampere format
+                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function):

             outlier_idx = torch.unique(coo_tensorA.colidx)
             state.idx = outlier_idx
-            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
-            #     # do not use pool for 2nd FFN layer
-            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-            # else:
-            #     state.idx = outlier_idx
             outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
             state.subB = (
                 (outliers * state.SCB.view(-1, 1) / 127.0)
@@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."

         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()

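Aside, not part of the diff: the switch from view() to reshape() matters because the incoming gradient is not guaranteed to be contiguous, and view() rejects non-contiguous layouts while reshape() copies when needed. A minimal PyTorch sketch of the failure mode this presumably avoids (my own illustration, not code from the commit):

import torch

g = torch.randn(2, 3, 4).transpose(0, 1)  # non-contiguous gradient, shape [3, 2, 4]
# g.view(-1, 4)                           # raises RuntimeError on this layout
g2 = g.reshape(-1, 4).contiguous()        # reshape copies if needed, so it always works
print(g2.shape)                           # torch.Size([6, 4])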
|
@@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function):

         if req_gradA:
             C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+            if state.CxBt is None and state.has_fp16_weights:
+                CBt = state.CBt
+            elif state.CxBt is None:
+                assert state.CBt is None
+                CB = state.CB.half()
+                SCB = state.SCB.unsqueeze(1).half()
+                SCBt = state.SCBt.unsqueeze(1).half()
+                Bt = (CB * SCB).t().contiguous()
+                CBt = (Bt / SCBt).t().to(torch.int8)
+
+            CxBt, SBt = F.transform(
+                CBt, to_order=formatB, transpose=True
+            )
+            gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)

         if req_gradBias:
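For readers following the new backward path: when only the inference-time state (CB, SCB, SCBt) is available, the transposed int8 weight is rebuilt by dequantizing the row-wise quantization and requantizing column-wise. Below is a standalone sketch of that arithmetic in plain PyTorch; it is my own illustration with made-up shapes, not code from this commit, and the round/clamp is an extra guard in the sketch:

import torch

torch.manual_seed(0)
B = torch.randn(8, 16)                      # small stand-in for the weight, [out, in]

SCB = B.abs().amax(dim=1)                   # per-row absmax scales, like state.SCB
SCBt = B.abs().amax(dim=0)                  # per-column absmax scales, like state.SCBt
CB = torch.round(B * 127.0 / SCB.unsqueeze(1)).to(torch.int8)       # row-wise int8, like state.CB
CBt_ref = torch.round(B * 127.0 / SCBt.unsqueeze(0)).to(torch.int8) # column-wise int8, like state.CBt

# The reconstruction used above: dequantize CB, transpose, requantize column-wise.
Bt = (CB.float() * SCB.unsqueeze(1)).t().contiguous()               # roughly 127 * B^T
CBt = (Bt / SCBt.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8).t()

print((CBt.int() - CBt_ref.int()).abs().max())  # small: only int8 rounding error remains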
@@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter):
         has_fp16_weights=False,
         CB=None,
         SCB=None,
+        SCBt=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
+        cls.SCBt = None
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter):
             B = self.data.contiguous().half().cuda(device)
             CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
             del CBt
-            del SCBt
             self.data = CB
             setattr(self, "CB", CB)
             setattr(self, "SCB", SCB)
+            setattr(self, "SCBt", SCBt)

         return self

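Context for the change above, as a sketch rather than a definitive statement about library internals: double_quant returns both a row-wise and a column-wise int8 quantization of the weight along with their absmax statistics. The large column-wise tensor CBt is still discarded to save memory, but its statistics SCBt are now kept on the parameter so the backward pass can rebuild CBt on demand. Assuming a CUDA device and this patched bitsandbytes:

import torch
import bitsandbytes as bnb

W = torch.randn(32, 64, dtype=torch.float16, device="cuda")
CB, CBt, SCB, SCBt, _ = bnb.functional.double_quant(W)
print(CB.shape, SCB.shape)    # row-wise int8 weight and its per-row scales
print(CBt.shape, SCBt.shape)  # column-wise int8 weight and its per-column scales
del CBt                       # the big tensor is dropped; SCBt alone is enough for backward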
|
@@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
+            new_param.SCBt = self.SCBt

             return new_param

@@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear):
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
+        self.state.SCBt = self.weight.SCBt
         self.weight.CB = None
         self.weight.SCB = None
+        self.weight.SCBt = None

     def forward(self, x):
         self.state.is_training = self.training
@@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear):

         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights and self.state.CB is not None:
+        # if not self.state.has_fp16_weights and self.state.CB is not None:
             # we converted 8-bit row major to turing/ampere format in the first inference pass
             # we no longer need the row-major weight
-            del self.state.CB
-            self.weight.data = self.state.CxB
+            # del self.state.CB
+            # self.weight.data = self.state.CxB

         return out

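Taken together, the point of the commit: with frozen int8 weights the backward pass no longer asserts, so gradients can flow through a Linear8bitLt layer to its inputs (the int8 weights themselves still receive no gradient). A hedged end-to-end sketch, assuming a CUDA device and this patched build; the constructor arguments shown already exist on Linear8bitLt:

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0).cuda()
x = torch.randn(4, 64, dtype=torch.float16, device="cuda", requires_grad=True)

out = layer(x)                        # first pass quantizes and caches the weight state
out.float().pow(2).mean().backward()  # previously hit: "Backprop only supported for fp16 weights."
print(x.grad.shape)                   # torch.Size([4, 64]): grad w.r.t. the activations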