Added fused bias to matmullt.
This commit is contained in:
parent
dede343033
commit
de354f7ded
|
@ -201,13 +201,14 @@ class MatmulLtState:
|
||||||
|
|
||||||
class MatMul8bitLt(torch.autograd.Function):
|
class MatMul8bitLt(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, A, B, out=None, state=MatmulLtState()):
|
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
|
||||||
# default to pytorch behavior if inputs are empty
|
# default to pytorch behavior if inputs are empty
|
||||||
ctx.is_empty = False
|
ctx.is_empty = False
|
||||||
if prod(A.shape) == 0:
|
if prod(A.shape) == 0:
|
||||||
ctx.is_empty = True
|
ctx.is_empty = True
|
||||||
ctx.A = A
|
ctx.A = A
|
||||||
ctx.B = B
|
ctx.B = B
|
||||||
|
ctx.bias = bias
|
||||||
if A.shape[-1] == B.shape[0]:
|
if A.shape[-1] == B.shape[0]:
|
||||||
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
|
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
|
||||||
else:
|
else:
|
||||||
|
@ -220,6 +221,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
# 5. Save state
|
# 5. Save state
|
||||||
requires_gradA = A.requires_grad
|
requires_gradA = A.requires_grad
|
||||||
requires_gradB = B.requires_grad
|
requires_gradB = B.requires_grad
|
||||||
|
requires_gradBias = bias is not None and bias.requires_grad
|
||||||
formatB = state.formatB
|
formatB = state.formatB
|
||||||
input_shape = A.shape
|
input_shape = A.shape
|
||||||
if state.outlier_pool is None:
|
if state.outlier_pool is None:
|
||||||
|
@ -247,28 +249,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
if state.CxB is None:
|
if state.CxB is None:
|
||||||
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
|
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
|
||||||
# we also need to convert it to the turing/ampere format
|
# we also need to convert it to the turing/ampere format
|
||||||
state.CxB, state.SB = F.transform(
|
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||||
state.CB, to_order=formatB
|
|
||||||
)
|
|
||||||
# state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
|
|
||||||
# if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
|
|
||||||
# # generate outlier index and subB
|
|
||||||
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
|
|
||||||
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
|
|
||||||
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
|
|
||||||
# # do not use pool for 2nd FFN layer
|
|
||||||
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
|
|
||||||
# else:
|
|
||||||
# state.idx = outlier_idx
|
|
||||||
# state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
|
|
||||||
|
|
||||||
# if state.idx is not None:
|
|
||||||
# # extract outliers
|
|
||||||
# CA[:, state.idx] = 0
|
|
||||||
# CAt[:, state.idx] = 0
|
|
||||||
# subA = A[:, state.idx]
|
|
||||||
# else:
|
|
||||||
# subA = None
|
|
||||||
else:
|
else:
|
||||||
if not state.has_fp16_weights and state.CxB is None:
|
if not state.has_fp16_weights and state.CxB is None:
|
||||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||||
|
@ -326,7 +307,8 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
# 3. Matmul
|
# 3. Matmul
|
||||||
C32A, SA = F.transform(CA, "col32")
|
C32A, SA = F.transform(CA, "col32")
|
||||||
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
|
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
|
||||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
|
# we apply the fused bias here
|
||||||
|
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
|
||||||
|
|
||||||
# 4. Mixed-precision decomposition matmul
|
# 4. Mixed-precision decomposition matmul
|
||||||
if coo_tensorA is not None and subA is not None:
|
if coo_tensorA is not None and subA is not None:
|
||||||
|
@ -337,7 +319,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
|
|
||||||
ctx.formatB = formatB
|
ctx.formatB = formatB
|
||||||
ctx.grad_shape = input_shape
|
ctx.grad_shape = input_shape
|
||||||
ctx.req_grads = [requires_gradA, requires_gradB]
|
ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
|
||||||
|
|
||||||
if requires_gradA or requires_gradB:
|
if requires_gradA or requires_gradB:
|
||||||
ctx.tensors = (CAt, subA)
|
ctx.tensors = (CAt, subA)
|
||||||
|
@ -347,15 +329,16 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
ctx.tensor_states = (None, None)
|
ctx.tensor_states = (None, None)
|
||||||
ctx.save_for_backward(None, None)
|
ctx.save_for_backward(None, None)
|
||||||
|
|
||||||
# clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
||||||
clone_func = torch.clone
|
#clone_func = torch.clone
|
||||||
return clone_func(output.view(output_shape))
|
return clone_func(output.view(output_shape))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
if ctx.is_empty:
|
if ctx.is_empty:
|
||||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
|
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
|
||||||
req_gradA, req_gradB = ctx.req_grads
|
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||||
|
req_gradA, req_gradB, req_gradBias = ctx.req_grads
|
||||||
CAt, subA = ctx.tensors
|
CAt, subA = ctx.tensors
|
||||||
SCAt, idx = ctx.tensor_states
|
SCAt, idx = ctx.tensor_states
|
||||||
formatB = ctx.formatB
|
formatB = ctx.formatB
|
||||||
|
@ -369,7 +352,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
-1, grad_output.shape[-1]
|
-1, grad_output.shape[-1]
|
||||||
).contiguous()
|
).contiguous()
|
||||||
|
|
||||||
grad_A = grad_B = None
|
grad_A = grad_B = grad_bias = None
|
||||||
|
|
||||||
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
|
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
|
||||||
if req_gradB:
|
if req_gradB:
|
||||||
|
@ -387,11 +370,12 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
state.CBt, to_order=formatB, transpose=True
|
state.CBt, to_order=formatB, transpose=True
|
||||||
)
|
)
|
||||||
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
|
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
|
||||||
ctx.grad_shape
|
|
||||||
)
|
|
||||||
|
|
||||||
return grad_A, grad_B, None, None
|
if req_gradBias:
|
||||||
|
grad_bias = grad_output.sum(0)
|
||||||
|
|
||||||
|
return grad_A, grad_B, None, grad_bias, None
|
||||||
|
|
||||||
|
|
||||||
matmul = MatMul8bitLt.apply
|
matmul = MatMul8bitLt.apply
|
||||||
|
@ -403,8 +387,9 @@ def matmul(
|
||||||
out: tensor = None,
|
out: tensor = None,
|
||||||
state: MatmulLtState = None,
|
state: MatmulLtState = None,
|
||||||
threshold=0.0,
|
threshold=0.0,
|
||||||
|
bias=None
|
||||||
):
|
):
|
||||||
state = state or MatmulLtState()
|
state = state or MatmulLtState()
|
||||||
if threshold > 0.0:
|
if threshold > 0.0:
|
||||||
state.threshold = threshold
|
state.threshold = threshold
|
||||||
return MatMul8bitLt.apply(A, B, out, state)
|
return MatMul8bitLt.apply(A, B, out, bias, state)
|
||||||
|
|
|
@ -235,9 +235,7 @@ class Linear8bitLt(nn.Linear):
|
||||||
if threshold > 0.0 and not has_fp16_weights:
|
if threshold > 0.0 and not has_fp16_weights:
|
||||||
self.state.use_pool = True
|
self.state.use_pool = True
|
||||||
|
|
||||||
self.weight = Int8Params(
|
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
|
||||||
self.weight.data, has_fp16_weights=has_fp16_weights
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_8bit_state(self):
|
def init_8bit_state(self):
|
||||||
self.state.CB = self.weight.CB
|
self.state.CB = self.weight.CB
|
||||||
|
@ -250,13 +248,12 @@ class Linear8bitLt(nn.Linear):
|
||||||
|
|
||||||
if self.weight.CB is not None:
|
if self.weight.CB is not None:
|
||||||
self.init_8bit_state()
|
self.init_8bit_state()
|
||||||
|
if self.bias.dtype != torch.float16:
|
||||||
|
self.bias.data = self.bias.data.half()
|
||||||
# assert not self.state.has_fp16_weights
|
# assert not self.state.has_fp16_weights
|
||||||
# if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
|
# if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
|
||||||
|
|
||||||
out = bnb.matmul(x, self.weight, state=self.state)
|
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||||
|
|
||||||
if self.bias is not None:
|
|
||||||
out += self.bias.unsqueeze(0).expand_as(out)
|
|
||||||
|
|
||||||
if not self.state.has_fp16_weights and self.state.CB is not None:
|
if not self.state.has_fp16_weights and self.state.CB is not None:
|
||||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from itertools import product
|
from itertools import product, permutations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -241,11 +241,20 @@ decomp = [0.0, 6.0]
|
||||||
funcs = [(torch.matmul, bnb.matmul)]
|
funcs = [(torch.matmul, bnb.matmul)]
|
||||||
str_funcs = ["matmul"]
|
str_funcs = ["matmul"]
|
||||||
req_grad = [(False, False), (True, False), (True, True), (False, True)]
|
req_grad = [(False, False), (True, False), (True, True), (False, True)]
|
||||||
req_grad_str = ["FF", "TF", "TT", "FT"]
|
req_grad = list(product([True, False], repeat=3))
|
||||||
|
req_grad_str = []
|
||||||
|
for c in req_grad:
|
||||||
|
strval = ''
|
||||||
|
for v in c:
|
||||||
|
if v == True: strval += 'T'
|
||||||
|
else: strval += 'F'
|
||||||
|
req_grad_str.append(strval)
|
||||||
|
|
||||||
transpose = [(False, True), (False, False)]
|
transpose = [(False, True), (False, False)]
|
||||||
str_transpose = ["NT", "NN"]
|
str_transpose = ["NT", "NN"]
|
||||||
dtype = [torch.float16]
|
dtype = [torch.float16]
|
||||||
has_fp16_weights = [True, False]
|
has_fp16_weights = [True, False]
|
||||||
|
has_bias = [True, False]
|
||||||
values = list(
|
values = list(
|
||||||
product(
|
product(
|
||||||
dim1,
|
dim1,
|
||||||
|
@ -258,6 +267,7 @@ values = list(
|
||||||
transpose,
|
transpose,
|
||||||
decomp,
|
decomp,
|
||||||
has_fp16_weights,
|
has_fp16_weights,
|
||||||
|
has_bias
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
str_values = list(
|
str_values = list(
|
||||||
|
@ -272,18 +282,14 @@ str_values = list(
|
||||||
str_transpose,
|
str_transpose,
|
||||||
decomp,
|
decomp,
|
||||||
has_fp16_weights,
|
has_fp16_weights,
|
||||||
|
has_bias
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
names = [
|
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values]
|
||||||
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
|
|
||||||
*vals
|
|
||||||
)
|
|
||||||
for vals in str_values
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
|
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
|
||||||
values,
|
values,
|
||||||
ids=names,
|
ids=names,
|
||||||
)
|
)
|
||||||
|
@ -298,10 +304,14 @@ def test_matmullt(
|
||||||
transpose,
|
transpose,
|
||||||
decomp,
|
decomp,
|
||||||
has_fp16_weights,
|
has_fp16_weights,
|
||||||
|
has_bias
|
||||||
):
|
):
|
||||||
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
||||||
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
||||||
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
|
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
|
||||||
|
if has_bias == False:
|
||||||
|
req_grad = list(req_grad)
|
||||||
|
req_grad[2] = False
|
||||||
|
|
||||||
for i in range(k):
|
for i in range(k):
|
||||||
|
|
||||||
|
@ -322,6 +332,11 @@ def test_matmullt(
|
||||||
requires_grad=req_grad[1],
|
requires_grad=req_grad[1],
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
bias = None
|
||||||
|
bias2 = None
|
||||||
|
if has_bias:
|
||||||
|
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
|
||||||
|
bias2 = bias.clone()
|
||||||
torch.nn.init.xavier_uniform_(B)
|
torch.nn.init.xavier_uniform_(B)
|
||||||
B2 = B.clone()
|
B2 = B.clone()
|
||||||
|
|
||||||
|
@ -342,10 +357,13 @@ def test_matmullt(
|
||||||
|
|
||||||
if not transpose[0] and transpose[1]:
|
if not transpose[0] and transpose[1]:
|
||||||
out_torch = funcs[0](A, B.t())
|
out_torch = funcs[0](A, B.t())
|
||||||
out_bnb = funcs[1](A, B2, state=state)
|
out_bnb = funcs[1](A, B2, state=state, bias=bias2)
|
||||||
elif not transpose[0] and not transpose[1]:
|
elif not transpose[0] and not transpose[1]:
|
||||||
out_torch = funcs[0](A, B)
|
out_torch = funcs[0](A, B)
|
||||||
out_bnb = funcs[1](A, B2.t(), state=state)
|
out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)
|
||||||
|
|
||||||
|
if has_bias:
|
||||||
|
out_torch += bias
|
||||||
|
|
||||||
n = out_bnb.numel()
|
n = out_bnb.numel()
|
||||||
err = torch.abs(out_bnb - out_torch).mean().item()
|
err = torch.abs(out_bnb - out_torch).mean().item()
|
||||||
|
@ -367,6 +385,9 @@ def test_matmullt(
|
||||||
gradB1 = B.grad
|
gradB1 = B.grad
|
||||||
A.grad = None
|
A.grad = None
|
||||||
B.grad = None
|
B.grad = None
|
||||||
|
if has_bias:
|
||||||
|
gradBias1 = bias.grad
|
||||||
|
bias.grad = None
|
||||||
|
|
||||||
loss_torch = torch.nn.functional.mse_loss(
|
loss_torch = torch.nn.functional.mse_loss(
|
||||||
out_torch, target
|
out_torch, target
|
||||||
|
@ -376,6 +397,9 @@ def test_matmullt(
|
||||||
gradB2 = B.grad
|
gradB2 = B.grad
|
||||||
A.grad = None
|
A.grad = None
|
||||||
B.grad = None
|
B.grad = None
|
||||||
|
if has_bias:
|
||||||
|
gradBias2 = bias.grad
|
||||||
|
bias.grad = None
|
||||||
|
|
||||||
if req_grad[0]:
|
if req_grad[0]:
|
||||||
torch.testing.assert_allclose(
|
torch.testing.assert_allclose(
|
||||||
|
@ -397,3 +421,6 @@ def test_matmullt(
|
||||||
torch.testing.assert_allclose(
|
torch.testing.assert_allclose(
|
||||||
gradB1, gradB2, atol=0.18, rtol=0.3
|
gradB1, gradB2, atol=0.18, rtol=0.3
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if req_grad[2]:
|
||||||
|
torch.testing.assert_allclose(gradBias1, gradBias2)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user