forked from mrq/bitsandbytes-rocm
memory efficient fp16 backward
This commit is contained in:
parent
ef2936a90d
commit
4d6174bc63
|
@ -196,7 +196,6 @@ class MatmulLtState:
|
|||
|
||||
self.CxBt = None
|
||||
self.SBt = None
|
||||
self.CBt = None
|
||||
|
||||
|
||||
class MatMul8bitLt(torch.autograd.Function):
|
||||
|
@ -327,15 +326,12 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
#clone_func = torch.clone
|
||||
return clone_func(output.view(output_shape))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||
req_gradA, req_gradB, req_gradBias = ctx.req_grads
|
||||
CAt, subA = ctx.tensors
|
||||
SCAt, idx = ctx.tensor_states
|
||||
formatB = ctx.formatB
|
||||
assert not req_gradB, "TODO: support weight updates as well"
|
||||
state = ctx.state
|
||||
|
||||
if len(grad_output.shape) == 3:
|
||||
|
@ -345,37 +341,11 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
grad_A = grad_B = grad_bias = None
|
||||
|
||||
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
|
||||
if req_gradB:
|
||||
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
|
||||
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
|
||||
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
|
||||
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
|
||||
if state.threshold > 0.0 and subA is not None:
|
||||
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
|
||||
|
||||
if req_gradA:
|
||||
C32grad, Sgrad = F.transform(Cgrad, "col32")
|
||||
if state.CxBt is None:
|
||||
if state.has_fp16_weights:
|
||||
CBt = state.CBt
|
||||
else:
|
||||
# Restore CBt from CB
|
||||
assert state.CBt is None, "CBt should not be stored in state"
|
||||
CB = state.CB.half()
|
||||
SCB = state.SCB.unsqueeze(1).half()
|
||||
SCBt = state.SCBt.unsqueeze(1).half()
|
||||
Bt = (CB * SCB).t().contiguous()
|
||||
CBt = (Bt / SCBt).t().to(torch.int8)
|
||||
|
||||
# intentionally, do not store CxBt in state
|
||||
CxBt, SBt = F.transform(
|
||||
CBt, to_order=formatB, transpose=True
|
||||
)
|
||||
else:
|
||||
CxBt = state.CxBt
|
||||
gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
|
||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
|
||||
CB = state.CB.half()
|
||||
SCB = state.SCB.unsqueeze(1).half()
|
||||
B = (CB * SCB) / 127.0
|
||||
grad_A = torch.mm(grad_output, B).view(ctx.grad_shape)
|
||||
|
||||
if req_gradBias:
|
||||
grad_bias = grad_output.sum(0)
|
||||
|
|
|
@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter):
|
|||
has_fp16_weights=False,
|
||||
CB=None,
|
||||
SCB=None,
|
||||
SCBt=None,
|
||||
):
|
||||
cls.has_fp16_weights = has_fp16_weights
|
||||
cls.CB = None
|
||||
cls.SCB = None
|
||||
cls.SCBt = None
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
return torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
|
@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter):
|
|||
B = self.data.contiguous().half().cuda(device)
|
||||
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
self.data = CB
|
||||
setattr(self, "CB", CB)
|
||||
setattr(self, "SCB", SCB)
|
||||
setattr(self, "SCBt", SCBt)
|
||||
|
||||
return self
|
||||
|
||||
|
@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter):
|
|||
)
|
||||
new_param.CB = self.CB
|
||||
new_param.SCB = self.SCB
|
||||
new_param.SCBt = self.SCBt
|
||||
|
||||
return new_param
|
||||
|
||||
|
@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear):
|
|||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
self.state.SCB = self.weight.SCB
|
||||
self.state.SCBt = self.weight.SCBt
|
||||
self.weight.CB = None
|
||||
self.weight.SCB = None
|
||||
self.weight.SCBt = None
|
||||
|
||||
def forward(self, x):
|
||||
self.state.is_training = self.training
|
||||
|
|
Loading…
Reference in New Issue
Block a user