2022-08-08 12:20:36 +00:00
|
|
|
import operator
|
2022-07-22 21:41:05 +00:00
|
|
|
import torch
|
|
|
|
import bitsandbytes.functional as F
|
|
|
|
|
2022-08-08 12:20:36 +00:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from functools import reduce # Required in Python 3
|
|
|
|
|
2022-08-08 16:13:22 +00:00
|
|
|
# math.prod is not available before Python 3.8, so use this local fallback.
|
2022-08-08 12:20:36 +00:00
|
|
|
def prod(iterable):
    """Return the product of every item in *iterable*; 1 when it is empty.

    Stand-in for ``math.prod`` so the module also runs on Python < 3.8.
    """
    result = 1
    for item in iterable:
        # Plain ``*`` (not ``*=``) so tensor-valued items are never
        # multiplied in place.
        result = result * item
    return result
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
# Short alias for torch.Tensor; used in the parameter annotations of matmul().
tensor = torch.Tensor
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
"""
|
2022-07-22 21:41:05 +00:00
|
|
|
This class pools outlier dimensions across layers.
|
|
|
|
This is particularly important for small models where outlier features
|
|
|
|
are less systematic and occur with low frequency.
|
2022-08-01 10:31:48 +00:00
|
|
|
"""
|
2022-07-22 21:41:05 +00:00
|
|
|
class GlobalOutlierPooler(object):
|
|
|
|
_instance = None
|
|
|
|
|
|
|
|
def __init__(self):
|
2022-08-01 10:31:48 +00:00
|
|
|
raise RuntimeError("Call get_instance() instead")
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
def initialize(self):
|
|
|
|
self.outliers = set()
|
|
|
|
self.model_dim = None
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
def get_instance(cls):
|
|
|
|
if cls._instance is None:
|
|
|
|
cls._instance = cls.__new__(cls)
|
|
|
|
cls._instance.initialize()
|
|
|
|
return cls._instance
|
|
|
|
|
|
|
|
def add_outliers(self, outlier_idx, feature_dim):
|
2022-08-01 10:31:48 +00:00
|
|
|
if self.model_dim is None:
|
|
|
|
self.model_dim = feature_dim
|
|
|
|
if feature_dim != self.model_dim:
|
|
|
|
return # we do not encode outliers for the 2nd FFN layer
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
self.outliers.update(outlier_idx.tolist())
|
|
|
|
|
|
|
|
def get_current_outlier_idx(self):
|
|
|
|
return torch.Tensor(list(self.outliers)).to(torch.int64)
|
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
class MatMul8bit(torch.autograd.Function):
    """Matmul whose forward and backward products can each run through 8-bit
    vector-wise quantization (``F.vectorwise_quant`` + ``F.igemm``).

    ``precision`` is a 3-element sequence choosing, per product, the 8-bit
    path (value ``8``) or a plain ``torch.matmul`` (any other value):
    ``precision[0]`` -> forward output, ``precision[1]`` -> gradient of B,
    ``precision[2]`` -> gradient of A.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
        """Compute ``A @ B``.

        Args:
            A: left operand.
            B: right operand, 2-D or 3-D (batched).
            out: accepted for API compatibility; not used.
            quant_type: mode forwarded to the ``F.vectorwise_*`` helpers.
            precision: see the class docstring; defaults to ``[8, 8, 8]``.

        Returns:
            The (possibly dequantized) product.
        """
        if precision is None:
            # Sentinel instead of a shared mutable default list.
            precision = [8, 8, 8]

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            # Quantize B along its contraction dimension: dim 0 when 2-D,
            # dim 1 when batched.
            dim = 0 if len(B.shape) == 2 else 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(A, B, out, quant_type, precision)``."""
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if ctx.needs_input_grad[1]:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            elif len(B.shape) == 2 and len(A.shape) == 3:
                # 2-D weight with batched input: fold the leading dims of
                # grad_output and A together so a single 2-D igemm gives
                # grad_B.
                grad_output = grad_output.contiguous()
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output.view(-1, grad_output.shape[2]),
                    dim=0,
                    quant_type=quant_type,
                )
                if not A.is_contiguous():
                    A = A.contiguous()
                qA, S2 = F.vectorwise_quant(
                    A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                )
                igrad_B = F.igemm(qA.t(), qgrad_output)
                grad_B = F.vectorwise_mm_dequant(
                    igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                )
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                grad_B = F.vectorwise_mm_dequant(
                    igrad_B,
                    S2.permute(permute_dim),
                    S1,
                    grad_output.dtype,
                    quant_type,
                )

        if ctx.needs_input_grad[0]:
            dims = [2] if len(grad_output.shape) == 3 else [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = dims
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A,
                    S1,
                    S3.permute(permute_dim),
                    grad_output.dtype,
                    quant_type,
                )

        return grad_A, grad_B, None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
# Public aliases: the plain, batched and generic entry points all dispatch to
# the same autograd function.
mm_cublas = bmm_cublas = matmul_cublas = MatMul8bit.apply
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
@dataclass
class MatmulLtState:
    """Mutable cache shared by ``MatMul8bitLt``'s forward and backward passes.

    NOTE: the attributes are deliberately left unannotated, so ``@dataclass``
    does not turn them into fields. ``__init__`` takes no arguments, and each
    instance reads these class-level defaults until a value is assigned on it.
    """

    # Forward-pass quantization of B: the int8 matrix from F.double_quant
    # (CB), its F.transform-ed layout for F.igemmlt (CxB), the state returned
    # by that transform (SB; SB[0] is read as B's shape) and the scaling
    # statistics returned by F.double_quant (SCB).
    CB = None
    CxB = None
    SB = None
    SCB = None

    # Transposed quantization of B, used when computing the gradient of A.
    CxBt = None
    SBt = None
    CBt = None
    # Declared here (and cleared in reset_grads) because MatMul8bitLt writes
    # it in forward and reads it in backward.
    SCBt = None

    # fp16 copy of B's outlier columns for the mixed-precision matmul.
    subB = None

    # GlobalOutlierPooler singleton, attached on the first forward call.
    outlier_pool = None
    # Not read anywhere in this module's visible code.
    has_accumulated_gradients = False
    # Threshold passed to F.double_quant when quantizing the input A.
    threshold = 0.0
    # Column indices of the outlier features found by the latest forward.
    idx = None
    # While training, B is re-quantized on each forward unless B.grad is set.
    is_training = True
    # When False, B must already be quantized into ``CB``; backward requires
    # this to be True.
    has_fp16_weights = True
    # Only referenced by the commented-out outlier-pooling code in forward.
    use_pool = False
    # Memory layout requested for B's transform.
    formatB = F.get_special_format_str()

    def reset_grads(self):
        """Drop every cached quantized form of B (forward and transposed)."""
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None
        self.SCBt = None
|
|
|
|
|
|
|
|
|
|
|
|
class MatMul8bitLt(torch.autograd.Function):
    """Int8 matmul (``F.igemmlt``) with optional mixed-precision decomposition
    of outlier columns and a fused bias.

    Quantized forms of B are cached on the ``MatmulLtState`` passed in, so
    they can be reused across calls that share a state.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=None):
        """Multiply fp16 ``A`` by the int8-quantized ``B`` and add ``bias``.

        Steps: 1. quantize A, 2. quantize B, 3. int8 matmul,
        4. fp16 matmul of the outlier columns, 5. save state for backward.

        Args:
            A: fp16 input, 2-D or 3-D; a 3-D input is flattened to 2-D for
                the matmul and the output is restored to 3-D.
            B: weight. With ``state.has_fp16_weights`` it is quantized here;
                otherwise ``state.CB`` must already hold its int8 form.
                NOTE(review): the output's last dim is ``state.SB[0][0]``, so
                B looks like (out_features, in_features) — confirm.
            out: accepted for API compatibility; not used.
            bias: optional bias, fused into ``F.mm_dequant``.
            state: ``MatmulLtState`` cache. A fresh one is created when
                ``None``; a single default instance is no longer shared
                between unrelated calls.

        Returns:
            The product shaped ``A.shape[:-1] + (state.SB[0][0],)``; empty
            inputs yield an uninitialized fp16 tensor of the matching shape.
        """
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                empty_shape = A.shape[:-1] + B.shape[1:]
            else:
                empty_shape = A.shape[:-1] + B.shape[:1]
            return torch.empty(empty_shape, dtype=torch.float16, device=A.device)

        if state is None:
            state = MatmulLtState()

        requires_gradA = A.requires_grad
        requires_gradB = B.requires_grad
        requires_gradBias = bias is not None and bias.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()
        assert (
            A.dtype == torch.float16
        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

        # 1. Quantize A
        if len(A.shape) == 3:
            # reshape (not view) so non-contiguous 3-D inputs also work.
            A = A.reshape(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
            A, threshold=state.threshold
        )

        subA = None
        if (
            state.threshold > 0.0
            and coo_tensorA is not None
            and state.has_fp16_weights
        ):
            # Split the outlier columns off: zero them in the int8 copies and
            # keep fp16 slices of A and B for the mixed-precision matmul.
            idx = torch.unique(coo_tensorA.colidx).long()
            CA[:, idx] = 0
            CAt[:, idx] = 0
            subA = A[:, idx]
            state.subB = B[:, idx].t().contiguous()
            state.idx = idx
        if not state.has_fp16_weights and state.CxB is None:
            # B is in 8-bit row-major; convert it once to the turing/ampere
            # layout expected by igemmlt.
            state.CxB, state.SB = F.transform(state.CB, to_order=formatB)

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = getattr(B, "grad", None) is not None
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            # Re-quantize unless a cached copy exists and we are either not
            # training or B already carries a .grad.
            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                CB, state.CBt, state.SCB, state.SCBt, _ = F.double_quant(B)
                state.CxB, state.SB = F.transform(CB, to_order=formatB)

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers from the already-quantized B
            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # Cross-layer outlier pooling is currently disabled:
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #     # do not use pool for 2nd FFN layer
            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #     state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            # Rescale the int8 outlier columns to fp16 (presumably SCB holds
            # per-row absmax scales, hence the division by 127).
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0)
                .t()
                .contiguous()
                .half()
            )
            idx_long = state.idx.long()
            CA[:, idx_long] = 0
            CAt[:, idx_long] = 0
            subA = A[:, idx_long]

        shapeB = state.SB[0]
        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        # we apply the fused bias here
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state
        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]

        if requires_gradA or requires_gradB:
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        result = output.view(output_shape)
        if len(output_shape) == 3:
            # NOTE(review): 3-D results are cloned rather than returned as a
            # view of ``output``; the reason is not visible here.
            result = torch.clone(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(A, B, out, bias, state)``.

        Only states with ``has_fp16_weights`` support backprop.
        """
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return (
                torch.zeros_like(ctx.A),
                torch.zeros_like(ctx.B),
                None,
                bias_grad,
                None,
            )
        req_gradA, req_gradB, req_gradBias = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        assert (
            state.has_fp16_weights
        ), "Backprop only supported for fp16 weights."

        if len(grad_output.shape) == 3:
            # reshape (not view): incoming gradients may be non-contiguous.
            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        grad_A = grad_B = grad_bias = None

        if req_gradA or req_gradB:
            # Only quantize grad_output when an A/B gradient is needed.
            Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output)

        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                # Add back the fp16 contribution of the outlier columns.
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            C32grad, Sgrad = F.transform(Cgrad, "col32")
            if state.CxBt is None:
                state.CxBt, state.SBt = F.transform(
                    state.CBt, to_order=formatB, transpose=True
                )
            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
            grad_A = F.mm_dequant(
                gradA32, SgradA32, SCgrad, state.SCBt
            ).view(ctx.grad_shape)

        if req_gradBias:
            grad_bias = grad_output.sum(0)

        return grad_A, grad_B, None, grad_bias, None
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
    """Run ``MatMul8bitLt`` on ``A`` and ``B``.

    Args:
        A: input tensor (``MatMul8bitLt`` requires fp16 for non-empty input).
        B: weight tensor.
        out: forwarded as-is; ``MatMul8bitLt`` does not use it.
        state: quantization cache; a new ``MatmulLtState`` is created when
            omitted.
        threshold: if positive, written to ``state.threshold``. This mutates
            a caller-supplied state.
        bias: optional bias added to the product.
    """
    if state is None:
        state = MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, bias, state)
|