forked from mrq/bitsandbytes-rocm
add memory effcient backward option
This commit is contained in:
parent
843ad0631c
commit
42b5fc9acc
|
@ -1,5 +1,6 @@
|
||||||
import operator
|
import operator
|
||||||
import torch
|
import torch
|
||||||
|
import bitsandbytes as bnb
|
||||||
import bitsandbytes.functional as F
|
import bitsandbytes.functional as F
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
@ -187,6 +188,8 @@ class MatmulLtState:
|
||||||
use_pool = False
|
use_pool = False
|
||||||
formatB = F.get_special_format_str()
|
formatB = F.get_special_format_str()
|
||||||
|
|
||||||
|
memory_efficient_backward = False
|
||||||
|
|
||||||
def reset_grads(self):
|
def reset_grads(self):
|
||||||
self.CB = None
|
self.CB = None
|
||||||
self.CxB = None
|
self.CxB = None
|
||||||
|
@ -283,6 +286,12 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
|
|
||||||
outlier_idx = torch.unique(coo_tensorA.colidx)
|
outlier_idx = torch.unique(coo_tensorA.colidx)
|
||||||
state.idx = outlier_idx
|
state.idx = outlier_idx
|
||||||
|
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
|
||||||
|
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
|
||||||
|
# # do not use pool for 2nd FFN layer
|
||||||
|
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
|
||||||
|
# else:
|
||||||
|
# state.idx = outlier_idx
|
||||||
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
|
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
|
||||||
state.subB = (
|
state.subB = (
|
||||||
(outliers * state.SCB.view(-1, 1) / 127.0)
|
(outliers * state.SCB.view(-1, 1) / 127.0)
|
||||||
|
@ -332,13 +341,15 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
||||||
return clone_func(output.view(output_shape))
|
return clone_func(output.view(output_shape))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
if ctx.is_empty:
|
if ctx.is_empty:
|
||||||
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
|
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
|
||||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||||
|
|
||||||
req_gradA, req_gradB, req_gradBias = ctx.req_grads
|
req_gradA, req_gradB, req_gradBias = ctx.req_grads
|
||||||
assert not req_gradB, "TODO: support weight updates as well"
|
CAt, subA = ctx.tensors
|
||||||
|
SCAt, idx = ctx.tensor_states
|
||||||
|
formatB = ctx.formatB
|
||||||
state = ctx.state
|
state = ctx.state
|
||||||
|
|
||||||
# Cast grad_output to fp16
|
# Cast grad_output to fp16
|
||||||
|
@ -352,11 +363,31 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
|
|
||||||
grad_A = grad_B = grad_bias = None
|
grad_A = grad_B = grad_bias = None
|
||||||
|
|
||||||
|
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
|
||||||
|
if req_gradB:
|
||||||
|
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
|
||||||
|
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
|
||||||
|
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
|
||||||
|
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
|
||||||
|
if state.threshold > 0.0 and subA is not None:
|
||||||
|
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
|
||||||
|
|
||||||
if req_gradA:
|
if req_gradA:
|
||||||
CB = state.CB.half()
|
if state.CBt:
|
||||||
SCB = (state.SCB.unsqueeze(1) / 127.0).half()
|
C32grad, Sgrad = F.transform(Cgrad, "col32")
|
||||||
CB *= SCB
|
if state.CxBt is None:
|
||||||
grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
|
state.CxBt, state.SBt = F.transform(
|
||||||
|
state.CBt, to_order=formatB, transpose=True
|
||||||
|
)
|
||||||
|
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||||
|
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
|
||||||
|
elif state.CB:
|
||||||
|
CB = state.CB.half()
|
||||||
|
SCB = (state.SCB.unsqueeze(1) / 127.0).half()
|
||||||
|
CB *= SCB
|
||||||
|
grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
|
||||||
|
else:
|
||||||
|
raise Exception('State must contain either CBt or CB matrix')
|
||||||
|
|
||||||
if req_gradBias:
|
if req_gradBias:
|
||||||
grad_bias = grad_output.sum(0)
|
grad_bias = grad_output.sum(0)
|
||||||
|
@ -367,6 +398,9 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
return grad_A, grad_B, None, grad_bias, None
|
return grad_A, grad_B, None, grad_bias, None
|
||||||
|
|
||||||
|
|
||||||
|
matmul = MatMul8bitLt.apply
|
||||||
|
|
||||||
|
|
||||||
def matmul(
|
def matmul(
|
||||||
A: tensor,
|
A: tensor,
|
||||||
B: tensor,
|
B: tensor,
|
||||||
|
|
|
@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear):
|
||||||
has_fp16_weights=True,
|
has_fp16_weights=True,
|
||||||
threshold=0.0,
|
threshold=0.0,
|
||||||
index=None,
|
index=None,
|
||||||
|
memory_efficient_backward=False
|
||||||
):
|
):
|
||||||
super(Linear8bitLt, self).__init__(
|
super(Linear8bitLt, self).__init__(
|
||||||
input_features, output_features, bias
|
input_features, output_features, bias
|
||||||
|
@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear):
|
||||||
|
|
||||||
self.state.threshold = threshold
|
self.state.threshold = threshold
|
||||||
self.state.has_fp16_weights = has_fp16_weights
|
self.state.has_fp16_weights = has_fp16_weights
|
||||||
|
self.state.memory_efficient_backward = memory_efficient_backward
|
||||||
if threshold > 0.0 and not has_fp16_weights:
|
if threshold > 0.0 and not has_fp16_weights:
|
||||||
self.state.use_pool = True
|
self.state.use_pool = True
|
||||||
|
|
||||||
|
@ -255,10 +257,16 @@ class Linear8bitLt(nn.Linear):
|
||||||
|
|
||||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||||
|
|
||||||
if not self.state.has_fp16_weights and self.state.CxB is not None:
|
if not self.state.has_fp16_weights:
|
||||||
# In this version, we convert 8-bit row major to turing/ampere format at each inference pass
|
if not self.state.memory_efficient_backward and self.state.CB is not None:
|
||||||
# Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
|
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||||
del self.state.CxB
|
# we no longer need the row-major weight
|
||||||
|
del self.state.CB
|
||||||
|
self.weight.data = self.state.CxB
|
||||||
|
elif self.state.memory_efficient_backward and self.state.CxB is not None:
|
||||||
|
# For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
|
||||||
|
# Thus, we delete CxB from the state.
|
||||||
|
del self.state.CxB
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user