add memory efficient backward

This commit is contained in:
dbaranchuk 2022-08-23 23:39:54 +03:00
parent 9d60b3c527
commit 8ae9bb23ad
2 changed files with 28 additions and 24 deletions

View File

@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function):
subA = A[:, idx] subA = A[:, idx]
state.subB = B[:, idx].t().contiguous() state.subB = B[:, idx].t().contiguous()
state.idx = idx state.idx = idx
else: elif state.CxB is None:
if state.CxB is None: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format
# we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else: else:
if not state.has_fp16_weights and state.CxB is None: if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function):
outlier_idx = torch.unique(coo_tensorA.colidx) outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = ( state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0) (outliers * state.SCB.view(-1, 1) / 127.0)
@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states SCAt, idx = ctx.tensor_states
formatB = ctx.formatB formatB = ctx.formatB
state = ctx.state state = ctx.state
assert (
state.has_fp16_weights
), "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3: if len(grad_output.shape) == 3:
grad_output = grad_output.view( grad_output = grad_output.reshape(
-1, grad_output.shape[-1] -1, grad_output.shape[-1]
).contiguous() ).contiguous()
@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function):
if req_gradA: if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, "col32") C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None: if state.CxBt is None and state.has_fp16_weights:
state.CxBt, state.SBt = F.transform( CBt = state.CBt
state.CBt, to_order=formatB, transpose=True elif state.CxBt is None:
) assert state.CBt is None
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) CB = state.CB.half()
SCB = state.SCB.unsquezee(1).half()
SCBt = state.SCBt.unsquezee(1).half()
Bt = (CB * SCB).t().contiguous()
CBt = (Bt / SCBt).t().to(torch.int8)
CxBt, SBt = F.transform(
CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
if req_gradBias: if req_gradBias:

View File

@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter):
has_fp16_weights=False, has_fp16_weights=False,
CB=None, CB=None,
SCB=None, SCB=None,
SCBt=None,
): ):
cls.has_fp16_weights = has_fp16_weights cls.has_fp16_weights = has_fp16_weights
cls.CB = None cls.CB = None
cls.SCB = None cls.SCB = None
cls.SCBt = None
if data is None: if data is None:
data = torch.empty(0) data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad) return torch.Tensor._make_subclass(cls, data, requires_grad)
@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter):
B = self.data.contiguous().half().cuda(device) B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt del CBt
del SCBt
self.data = CB self.data = CB
setattr(self, "CB", CB) setattr(self, "CB", CB)
setattr(self, "SCB", SCB) setattr(self, "SCB", SCB)
setattr(self, "SCBt", SCBt)
return self return self
@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter):
) )
new_param.CB = self.CB new_param.CB = self.CB
new_param.SCB = self.SCB new_param.SCB = self.SCB
new_param.SCB = self.SCBt
return new_param return new_param
@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear):
def init_8bit_state(self): def init_8bit_state(self):
self.state.CB = self.weight.CB self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB self.state.SCB = self.weight.SCB
self.state.SCBt = self.weight.SCBt
self.weight.CB = None self.weight.CB = None
self.weight.SCB = None self.weight.SCB = None
self.weight.SCBt = None
def forward(self, x): def forward(self, x):
self.state.is_training = self.training self.state.is_training = self.training
@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights and self.state.CB is not None: # if not self.state.has_fp16_weights and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass # we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight # we no longer need the row-major weight
del self.state.CB # del self.state.CB
self.weight.data = self.state.CxB # self.weight.data = self.state.CxB
return out return out