Added missing triton and fp8 files.
This commit is contained in:
parent
ec1ea63711
commit
e67bfccbcd
7
bitsandbytes/research/__init__.py
Normal file
7
bitsandbytes/research/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
from .autograd._functions import (
|
||||||
|
matmul_fp8,
|
||||||
|
switchback_bnb,
|
||||||
|
matmul_fp8_global,
|
||||||
|
matmul_fp8_mixed,
|
||||||
|
)
|
0
bitsandbytes/research/autograd/__init__.py
Normal file
0
bitsandbytes/research/autograd/__init__.py
Normal file
493
bitsandbytes/research/autograd/_functions.py
Normal file
493
bitsandbytes/research/autograd/_functions.py
Normal file
|
@ -0,0 +1,493 @@
|
||||||
|
import operator
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import reduce # Required in Python 3
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import bitsandbytes.functional as F
|
||||||
|
|
||||||
|
from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler
|
||||||
|
|
||||||
|
|
||||||
|
# math.prod not compatible with python < 3.8
|
||||||
|
def prod(iterable):
|
||||||
|
return reduce(operator.mul, iterable, 1)
|
||||||
|
|
||||||
|
tensor = torch.Tensor
|
||||||
|
|
||||||
|
class MatMulFP8(torch.autograd.Function):
|
||||||
|
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||||
|
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||||
|
# default of pytorch behavior if inputs are empty
|
||||||
|
ctx.is_empty = False
|
||||||
|
if prod(A.shape) == 0:
|
||||||
|
ctx.is_empty = True
|
||||||
|
ctx.A = A
|
||||||
|
ctx.B = B
|
||||||
|
|
||||||
|
B_shape = B.shape
|
||||||
|
if A.shape[-1] == B_shape[0]:
|
||||||
|
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||||
|
else:
|
||||||
|
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||||
|
|
||||||
|
# 1. Dequantize
|
||||||
|
# 2. MatmulnN
|
||||||
|
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
|
||||||
|
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
|
||||||
|
|
||||||
|
cB, state = F.quantize(B.float(), code=fw_code)
|
||||||
|
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||||
|
|
||||||
|
output = torch.matmul(fp8A, fp8B)
|
||||||
|
|
||||||
|
# output is half
|
||||||
|
|
||||||
|
# 3. Save state
|
||||||
|
ctx.fw_code = fw_code
|
||||||
|
ctx.bw_code = bw_code
|
||||||
|
ctx.bsz = bsz
|
||||||
|
ctx.bsz2 = bsz2
|
||||||
|
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||||
|
|
||||||
|
if any(ctx.needs_input_grad[:2]):
|
||||||
|
# NOTE: we send back A, and re-quant.
|
||||||
|
ctx.tensors = (A, fp8B)
|
||||||
|
else:
|
||||||
|
ctx.tensors = (None, None)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
if ctx.is_empty:
|
||||||
|
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
|
||||||
|
|
||||||
|
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||||
|
A, B = ctx.tensors
|
||||||
|
|
||||||
|
grad_A, grad_B = None, None
|
||||||
|
|
||||||
|
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
|
||||||
|
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
|
||||||
|
|
||||||
|
cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||||
|
fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||||
|
|
||||||
|
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||||
|
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||||
|
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||||
|
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||||
|
|
||||||
|
# not supported by PyTorch. TODO: create work-around
|
||||||
|
if req_gradA:
|
||||||
|
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||||
|
|
||||||
|
if req_gradB:
|
||||||
|
if len(A.shape) == 3:
|
||||||
|
At = A.transpose(2, 1).contiguous()
|
||||||
|
else:
|
||||||
|
At = A.transpose(1, 0).contiguous()
|
||||||
|
cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||||
|
fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||||
|
grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)
|
||||||
|
|
||||||
|
return grad_A, grad_B, None, None, None, None, None
|
||||||
|
|
||||||
|
class MatMulFP8Mixed(torch.autograd.Function):
|
||||||
|
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||||
|
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||||
|
# default of pytorch behavior if inputs are empty
|
||||||
|
ctx.is_empty = False
|
||||||
|
if prod(A.shape) == 0:
|
||||||
|
ctx.is_empty = True
|
||||||
|
ctx.A = A
|
||||||
|
ctx.B = B
|
||||||
|
|
||||||
|
B_shape = B.shape
|
||||||
|
if A.shape[-1] == B_shape[0]:
|
||||||
|
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||||
|
else:
|
||||||
|
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||||
|
|
||||||
|
# 1. Dequantize
|
||||||
|
# 2. MatmulnN
|
||||||
|
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
|
||||||
|
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
|
||||||
|
|
||||||
|
cB, state = F.quantize(B.float(), code=fw_code)
|
||||||
|
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||||
|
|
||||||
|
output = torch.matmul(fp8A, fp8B)
|
||||||
|
|
||||||
|
# output is half
|
||||||
|
|
||||||
|
# 3. Save state
|
||||||
|
ctx.fw_code = fw_code
|
||||||
|
ctx.bw_code = bw_code
|
||||||
|
ctx.bsz = bsz
|
||||||
|
ctx.bsz2 = bsz2
|
||||||
|
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||||
|
|
||||||
|
if any(ctx.needs_input_grad[:2]):
|
||||||
|
# NOTE: we send back A, and re-quant.
|
||||||
|
ctx.tensors = (A, fp8B)
|
||||||
|
else:
|
||||||
|
ctx.tensors = (None, None)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
if ctx.is_empty:
|
||||||
|
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
|
||||||
|
|
||||||
|
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||||
|
A, B = ctx.tensors
|
||||||
|
|
||||||
|
grad_A, grad_B = None, None
|
||||||
|
|
||||||
|
# TODO: Fix blocksize to be output_dim
|
||||||
|
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
|
||||||
|
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
|
||||||
|
|
||||||
|
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||||
|
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||||
|
|
||||||
|
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||||
|
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||||
|
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||||
|
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||||
|
|
||||||
|
# not supported by PyTorch. TODO: create work-around
|
||||||
|
if req_gradA:
|
||||||
|
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||||
|
|
||||||
|
if req_gradB:
|
||||||
|
At = A.transpose(2, 1).contiguous()
|
||||||
|
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||||
|
# fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||||
|
grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
|
||||||
|
|
||||||
|
return grad_A, grad_B, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class MatMulFP8Global(torch.autograd.Function):
|
||||||
|
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||||
|
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||||
|
# default of pytorch behavior if inputs are empty
|
||||||
|
ctx.is_empty = False
|
||||||
|
if prod(A.shape) == 0:
|
||||||
|
ctx.is_empty = True
|
||||||
|
ctx.A = A
|
||||||
|
ctx.B = B
|
||||||
|
|
||||||
|
B_shape = B.shape
|
||||||
|
if A.shape[-1] == B_shape[0]:
|
||||||
|
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||||
|
else:
|
||||||
|
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||||
|
|
||||||
|
# 1. Dequantize
|
||||||
|
# 2. MatmulnN
|
||||||
|
cA, state = F.quantize(A.float(), code=fw_code)
|
||||||
|
fp8A = F.dequantize(cA, state).to(A.dtype)
|
||||||
|
|
||||||
|
cB, state = F.quantize(B.float(), code=fw_code)
|
||||||
|
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||||
|
|
||||||
|
output = torch.matmul(fp8A, fp8B)
|
||||||
|
|
||||||
|
# output is half
|
||||||
|
|
||||||
|
# 3. Save state
|
||||||
|
ctx.fw_code = fw_code
|
||||||
|
ctx.bw_code = bw_code
|
||||||
|
ctx.bsz = bsz
|
||||||
|
ctx.bsz2 = bsz2
|
||||||
|
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||||
|
|
||||||
|
if any(ctx.needs_input_grad[:2]):
|
||||||
|
# NOTE: we send back A, and re-quant.
|
||||||
|
ctx.tensors = (A, fp8B)
|
||||||
|
else:
|
||||||
|
ctx.tensors = (None, None)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
if ctx.is_empty:
|
||||||
|
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
|
||||||
|
|
||||||
|
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||||
|
A, B = ctx.tensors
|
||||||
|
|
||||||
|
grad_A, grad_B = None, None
|
||||||
|
|
||||||
|
# TODO: Fix blocksize to be output_dim
|
||||||
|
cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||||
|
fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype)
|
||||||
|
|
||||||
|
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||||
|
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||||
|
|
||||||
|
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||||
|
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||||
|
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||||
|
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||||
|
|
||||||
|
# not supported by PyTorch. TODO: create work-around
|
||||||
|
if req_gradA:
|
||||||
|
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||||
|
|
||||||
|
if req_gradB:
|
||||||
|
At = A.transpose(2, 1).contiguous()
|
||||||
|
cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||||
|
fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||||
|
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
|
||||||
|
|
||||||
|
return grad_A, grad_B, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class MatMul8bitMixed(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
|
||||||
|
# default to pytorch behavior if inputs are empty
|
||||||
|
ctx.is_empty = False
|
||||||
|
if prod(A.shape) == 0:
|
||||||
|
ctx.is_empty = True
|
||||||
|
ctx.A = A
|
||||||
|
ctx.B = B
|
||||||
|
ctx.bias = bias
|
||||||
|
if A.shape[-1] == B.shape[0]:
|
||||||
|
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
|
||||||
|
else:
|
||||||
|
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
|
||||||
|
|
||||||
|
# 1. Quantize A
|
||||||
|
# 2. Quantize B
|
||||||
|
# 3. Matmul
|
||||||
|
# 4. Mixed-precision decomposition matmul
|
||||||
|
# 5. Save state
|
||||||
|
formatB = state.formatB
|
||||||
|
input_shape = A.shape
|
||||||
|
if state.outlier_pool is None:
|
||||||
|
state.outlier_pool = GlobalOutlierPooler.get_instance()
|
||||||
|
|
||||||
|
# Cast A to fp16
|
||||||
|
if A.dtype != torch.float16:
|
||||||
|
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
|
||||||
|
|
||||||
|
# 1. Quantize A
|
||||||
|
if len(A.shape) == 3:
|
||||||
|
A = A.view(-1, A.shape[-1]).contiguous()
|
||||||
|
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
|
||||||
|
A.to(torch.float16), threshold=state.threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
if state.threshold > 0.0 and coo_tensorA is not None:
|
||||||
|
if state.has_fp16_weights:
|
||||||
|
idx = torch.unique(coo_tensorA.colidx).long()
|
||||||
|
CA[:, idx] = 0
|
||||||
|
CAt[:, idx] = 0
|
||||||
|
subA = A[:, idx]
|
||||||
|
state.subB = B[:, idx].t().contiguous()
|
||||||
|
state.idx = idx
|
||||||
|
else:
|
||||||
|
if state.CxB is None:
|
||||||
|
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
|
||||||
|
# we also need to convert it to the turing/ampere format
|
||||||
|
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||||
|
else:
|
||||||
|
#print('A shape', A.shape)
|
||||||
|
if not state.has_fp16_weights and state.CxB is None:
|
||||||
|
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||||
|
subA = None
|
||||||
|
|
||||||
|
# 2. Quantize B
|
||||||
|
if state.has_fp16_weights:
|
||||||
|
#print('B shape', B.shape)
|
||||||
|
has_grad = True if (getattr(B, "grad", None) is not None) else False
|
||||||
|
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
|
||||||
|
if is_transposed:
|
||||||
|
B = B.contiguous()
|
||||||
|
|
||||||
|
if (state.is_training and not has_grad) or state.CxB is None:
|
||||||
|
state.reset_grads()
|
||||||
|
(
|
||||||
|
CB,
|
||||||
|
state.CBt,
|
||||||
|
state.SCB,
|
||||||
|
state.SCBt,
|
||||||
|
coo_tensorB,
|
||||||
|
) = F.double_quant(B.to(torch.float16))
|
||||||
|
state.CxB, state.SB = F.transform(CB, to_order=formatB)
|
||||||
|
else:
|
||||||
|
has_grad = False
|
||||||
|
|
||||||
|
if coo_tensorA is not None and not state.has_fp16_weights:
|
||||||
|
# extract outliers
|
||||||
|
|
||||||
|
outlier_idx = torch.unique(coo_tensorA.colidx)
|
||||||
|
state.idx = outlier_idx
|
||||||
|
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
|
||||||
|
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
|
||||||
|
# # do not use pool for 2nd FFN layer
|
||||||
|
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
|
||||||
|
# else:
|
||||||
|
# state.idx = outlier_idx
|
||||||
|
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
|
||||||
|
state.subB = (
|
||||||
|
(outliers * state.SCB.view(-1, 1) / 127.0)
|
||||||
|
.t()
|
||||||
|
.contiguous()
|
||||||
|
.to(A.dtype)
|
||||||
|
)
|
||||||
|
CA[:, state.idx.long()] = 0
|
||||||
|
CAt[:, state.idx.long()] = 0
|
||||||
|
subA = A[:, state.idx.long()]
|
||||||
|
|
||||||
|
shapeB = state.SB[0]
|
||||||
|
|
||||||
|
if len(input_shape) == 3:
|
||||||
|
output_shape = (input_shape[0], input_shape[1], shapeB[0])
|
||||||
|
else:
|
||||||
|
output_shape = (input_shape[0], shapeB[0])
|
||||||
|
|
||||||
|
# 3. Matmul
|
||||||
|
C32A, SA = F.transform(CA, "col32")
|
||||||
|
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
|
||||||
|
# we apply the fused bias here
|
||||||
|
|
||||||
|
if bias is None or bias.dtype == torch.float16:
|
||||||
|
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
|
||||||
|
output = output.to(A.dtype)
|
||||||
|
else: # apply bias separately
|
||||||
|
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
|
||||||
|
output = output.to(A.dtype).add_(bias)
|
||||||
|
|
||||||
|
# 4. Mixed-precision decomposition matmul
|
||||||
|
if coo_tensorA is not None and subA is not None:
|
||||||
|
output += torch.matmul(subA, state.subB)
|
||||||
|
|
||||||
|
# 5. Save state
|
||||||
|
ctx.state = state
|
||||||
|
|
||||||
|
ctx.formatB = formatB
|
||||||
|
ctx.grad_shape = input_shape
|
||||||
|
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
|
||||||
|
|
||||||
|
if any(ctx.needs_input_grad[:2]):
|
||||||
|
ctx.tensors = (CAt, subA, A)
|
||||||
|
ctx.tensor_states = (SCAt, state.idx)
|
||||||
|
else:
|
||||||
|
ctx.tensors = [None, None, None]
|
||||||
|
ctx.tensor_states = (None, None)
|
||||||
|
ctx.save_for_backward(None, None)
|
||||||
|
|
||||||
|
|
||||||
|
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
||||||
|
return clone_func(output.view(output_shape))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
if ctx.is_empty:
|
||||||
|
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
|
||||||
|
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||||
|
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
|
||||||
|
CAt, subA, A = ctx.tensors
|
||||||
|
SCAt, idx = ctx.tensor_states
|
||||||
|
formatB = ctx.formatB
|
||||||
|
state = ctx.state
|
||||||
|
grad_A = grad_B = grad_bias = None
|
||||||
|
|
||||||
|
if req_gradBias:
|
||||||
|
# compute grad_bias first before changing grad_output dtype
|
||||||
|
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
|
||||||
|
|
||||||
|
# Cast grad_output to fp16
|
||||||
|
if len(grad_output.shape) == 3:
|
||||||
|
grad_output = grad_output.reshape(
|
||||||
|
-1, grad_output.shape[-1]
|
||||||
|
).contiguous()
|
||||||
|
|
||||||
|
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
|
||||||
|
|
||||||
|
if req_gradB:
|
||||||
|
# print('back A shape', A.shape)
|
||||||
|
# print('grad output t shape', grad_output.t().shape)
|
||||||
|
grad_B = torch.matmul(grad_output.t(), A)
|
||||||
|
|
||||||
|
if req_gradA:
|
||||||
|
if state.CBt is not None:
|
||||||
|
C32grad, Sgrad = F.transform(Cgrad, "col32")
|
||||||
|
if state.CxBt is None:
|
||||||
|
state.CxBt, state.SBt = F.transform(
|
||||||
|
state.CBt, to_order=formatB, transpose=True
|
||||||
|
)
|
||||||
|
# print('back B shape', state.CxBt.shape)
|
||||||
|
# print('back grad shape', C32grad.shape)
|
||||||
|
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||||
|
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||||
|
|
||||||
|
elif state.CB is not None:
|
||||||
|
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
|
||||||
|
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||||
|
else:
|
||||||
|
raise Exception('State must contain either CBt or CB matrix for backward')
|
||||||
|
|
||||||
|
return grad_A, grad_B, None, grad_bias, None
|
||||||
|
|
||||||
|
def get_block_sizes(input_matrix, weight_matrix):
|
||||||
|
input_features = input_matrix.shape[-1]
|
||||||
|
output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1])
|
||||||
|
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||||
|
bsz, bsz2 = 1024, 1024
|
||||||
|
for i, k in enumerate(array):
|
||||||
|
if input_features > array[i + 1]:
|
||||||
|
bsz = k
|
||||||
|
break
|
||||||
|
for i, k in enumerate(array):
|
||||||
|
if output_features > array[i + 1]:
|
||||||
|
bsz2 = k
|
||||||
|
break
|
||||||
|
|
||||||
|
return bsz, bsz2
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||||
|
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
||||||
|
return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||||
|
|
||||||
|
def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||||
|
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
||||||
|
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||||
|
|
||||||
|
def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||||
|
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
||||||
|
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||||
|
|
||||||
|
|
||||||
|
def switchback_bnb(
|
||||||
|
A: tensor,
|
||||||
|
B: tensor,
|
||||||
|
out: tensor = None,
|
||||||
|
state: MatmulLtState = None,
|
||||||
|
threshold=0.0,
|
||||||
|
bias=None
|
||||||
|
):
|
||||||
|
state = state or MatmulLtState()
|
||||||
|
if threshold > 0.0:
|
||||||
|
state.threshold = threshold
|
||||||
|
return MatMul8bitMixed.apply(A, B, out, bias, state)
|
0
bitsandbytes/triton/__init__.py
Normal file
0
bitsandbytes/triton/__init__.py
Normal file
58
bitsandbytes/triton/dequantize_rowwise.py
Normal file
58
bitsandbytes/triton/dequantize_rowwise.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# rowwise quantize
|
||||||
|
|
||||||
|
# TODO: autotune this better.
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config({}, num_stages=1, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=2, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=4, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=8, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=1),
|
||||||
|
triton.Config({}, num_stages=2),
|
||||||
|
triton.Config({}, num_stages=4),
|
||||||
|
triton.Config({}, num_stages=8),
|
||||||
|
triton.Config({}, num_warps=1),
|
||||||
|
triton.Config({}, num_warps=2),
|
||||||
|
triton.Config({}, num_warps=4),
|
||||||
|
triton.Config({}, num_warps=8),
|
||||||
|
],
|
||||||
|
key=['n_elements']
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _dequantize_rowwise(
|
||||||
|
x_ptr,
|
||||||
|
state_x,
|
||||||
|
output_ptr,
|
||||||
|
inv_127,
|
||||||
|
n_elements,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
P2: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
block_start = pid * BLOCK_SIZE
|
||||||
|
arange = tl.arange(0, P2)
|
||||||
|
offsets = block_start + arange
|
||||||
|
row_mask = arange < BLOCK_SIZE
|
||||||
|
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||||
|
max_val = tl.load(state_x + pid)
|
||||||
|
output = max_val * x * inv_127
|
||||||
|
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
|
||||||
|
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
|
||||||
|
|
||||||
|
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||||
|
|
||||||
|
assert x.is_cuda and output.is_cuda
|
||||||
|
n_elements = output.numel()
|
||||||
|
grid = lambda meta: (x.shape[0],)
|
||||||
|
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||||
|
return output
|
158
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py
Normal file
158
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
|
||||||
|
# This is a matmul kernel based on triton.ops.matmul
|
||||||
|
# It is modified to support rowwise quantized input and global quantized weight
|
||||||
|
# It's purpose is fused matmul then dequantize
|
||||||
|
# It does support bias.
|
||||||
|
|
||||||
|
def init_to_zero(name):
|
||||||
|
return lambda nargs: nargs[name].zero_()
|
||||||
|
|
||||||
|
def get_configs_io_bound():
|
||||||
|
configs = []
|
||||||
|
for num_stages in [2, 3, 4, 5, 6]:
|
||||||
|
for block_m in [16, 32]:
|
||||||
|
for block_k in [32, 64]:
|
||||||
|
for block_n in [32, 64, 128, 256]:
|
||||||
|
num_warps = 2 if block_n <= 64 else 4
|
||||||
|
configs.append(
|
||||||
|
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||||
|
num_stages=num_stages, num_warps=num_warps))
|
||||||
|
# split_k
|
||||||
|
for split_k in [2, 4, 8, 16]:
|
||||||
|
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||||
|
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
# basic configs for compute-bound matmuls
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||||
|
# good for int8
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||||
|
] + get_configs_io_bound(),
|
||||||
|
key=['M', 'N', 'K'],
|
||||||
|
prune_configs_by={
|
||||||
|
'early_config_prune': early_config_prune,
|
||||||
|
'perf_model': estimate_matmul_time,
|
||||||
|
'top_k': 10
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@triton.heuristics({
|
||||||
|
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||||
|
})
|
||||||
|
@triton.jit
|
||||||
|
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
|
||||||
|
stride_am, stride_ak,
|
||||||
|
stride_bk, stride_bn,
|
||||||
|
stride_cm, stride_cn,
|
||||||
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||||
|
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||||
|
ACC_TYPE: tl.constexpr
|
||||||
|
):
|
||||||
|
# matrix multiplication
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
pid_z = tl.program_id(1)
|
||||||
|
grid_m = tl.cdiv(M, BLOCK_M)
|
||||||
|
grid_n = tl.cdiv(N, BLOCK_N)
|
||||||
|
# re-order program ID for better L2 performance
|
||||||
|
width = GROUP_M * grid_n
|
||||||
|
group_id = pid // width
|
||||||
|
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||||
|
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||||
|
pid_n = (pid % width) // (group_size)
|
||||||
|
# do matrix multiplication
|
||||||
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||||
|
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||||
|
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||||
|
# pointers
|
||||||
|
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||||
|
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||||
|
|
||||||
|
# rematerialize rm and rn to save registers
|
||||||
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
|
w_factor = tl.load(state_w_ptr)
|
||||||
|
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||||
|
|
||||||
|
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||||
|
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||||
|
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||||
|
if EVEN_K:
|
||||||
|
a = tl.load(A)
|
||||||
|
b = tl.load(B)
|
||||||
|
else:
|
||||||
|
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||||
|
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||||
|
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||||
|
acc += tl.dot(a, b)
|
||||||
|
A += BLOCK_K * SPLIT_K * stride_ak
|
||||||
|
B += BLOCK_K * SPLIT_K * stride_bk
|
||||||
|
|
||||||
|
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||||
|
acc = acc.to(C.dtype.element_ty)
|
||||||
|
|
||||||
|
# conditionally add bias
|
||||||
|
if has_bias:
|
||||||
|
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||||
|
acc = acc + bias[None, :]
|
||||||
|
|
||||||
|
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||||
|
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||||
|
# handles write-back with reduction-splitting
|
||||||
|
if SPLIT_K == 1:
|
||||||
|
tl.store(C, acc, mask=mask)
|
||||||
|
else:
|
||||||
|
tl.atomic_add(C, acc, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
|
||||||
|
device = a.device
|
||||||
|
divfactor = 1. / (127. * 127.)
|
||||||
|
has_bias = 0 if bias is None else 1
|
||||||
|
# handle non-contiguous inputs if necessary
|
||||||
|
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||||
|
a = a.contiguous()
|
||||||
|
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||||
|
b = b.contiguous()
|
||||||
|
# checks constraints
|
||||||
|
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||||
|
M, K = a.shape
|
||||||
|
_, N = b.shape
|
||||||
|
# allocates output
|
||||||
|
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||||
|
# accumulator types
|
||||||
|
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||||
|
# launch int8_matmul_mixed_dequantize kernel
|
||||||
|
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||||
|
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||||
|
a.stride(0), a.stride(1),
|
||||||
|
b.stride(0), b.stride(1),
|
||||||
|
c.stride(0), c.stride(1),
|
||||||
|
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||||
|
return c
|
159
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
Normal file
159
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# This is a matmul kernel based on triton.ops.matmul
|
||||||
|
# It is modified to support rowwise quantized input and columnwise quantized weight
|
||||||
|
# It's purpose is fused matmul then dequantize
|
||||||
|
# It does support bias.
|
||||||
|
|
||||||
|
def init_to_zero(name):
|
||||||
|
return lambda nargs: nargs[name].zero_()
|
||||||
|
|
||||||
|
|
||||||
|
def get_configs_io_bound():
|
||||||
|
configs = []
|
||||||
|
for num_stages in [2, 3, 4, 5, 6]:
|
||||||
|
for block_m in [16, 32]:
|
||||||
|
for block_k in [32, 64]:
|
||||||
|
for block_n in [32, 64, 128, 256]:
|
||||||
|
num_warps = 2 if block_n <= 64 else 4
|
||||||
|
configs.append(
|
||||||
|
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||||
|
num_stages=num_stages, num_warps=num_warps))
|
||||||
|
# split_k
|
||||||
|
for split_k in [2, 4, 8, 16]:
|
||||||
|
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||||
|
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
# basic configs for compute-bound matmuls
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||||
|
# good for int8
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||||
|
] + get_configs_io_bound(),
|
||||||
|
key=['M', 'N', 'K'],
|
||||||
|
prune_configs_by={
|
||||||
|
'early_config_prune': early_config_prune,
|
||||||
|
'perf_model': estimate_matmul_time,
|
||||||
|
'top_k': 10
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@triton.heuristics({
|
||||||
|
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||||
|
})
|
||||||
|
@triton.jit
|
||||||
|
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
|
||||||
|
stride_am, stride_ak,
|
||||||
|
stride_bk, stride_bn,
|
||||||
|
stride_cm, stride_cn,
|
||||||
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||||
|
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||||
|
ACC_TYPE: tl.constexpr
|
||||||
|
):
|
||||||
|
# matrix multiplication
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
pid_z = tl.program_id(1)
|
||||||
|
grid_m = tl.cdiv(M, BLOCK_M)
|
||||||
|
grid_n = tl.cdiv(N, BLOCK_N)
|
||||||
|
# re-order program ID for better L2 performance
|
||||||
|
width = GROUP_M * grid_n
|
||||||
|
group_id = pid // width
|
||||||
|
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||||
|
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||||
|
pid_n = (pid % width) // (group_size)
|
||||||
|
# do matrix multiplication
|
||||||
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||||
|
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||||
|
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||||
|
# pointers
|
||||||
|
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||||
|
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||||
|
|
||||||
|
# rematerialize rm and rn to save registers
|
||||||
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
|
w_factor = tl.load(state_w_ptr + rbn)[None, :]
|
||||||
|
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||||
|
|
||||||
|
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||||
|
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||||
|
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||||
|
if EVEN_K:
|
||||||
|
a = tl.load(A)
|
||||||
|
b = tl.load(B)
|
||||||
|
else:
|
||||||
|
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||||
|
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||||
|
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||||
|
acc += tl.dot(a, b)
|
||||||
|
A += BLOCK_K * SPLIT_K * stride_ak
|
||||||
|
B += BLOCK_K * SPLIT_K * stride_bk
|
||||||
|
|
||||||
|
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||||
|
acc = acc.to(C.dtype.element_ty)
|
||||||
|
|
||||||
|
if has_bias:
|
||||||
|
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||||
|
acc = acc + bias[None, :]
|
||||||
|
|
||||||
|
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||||
|
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||||
|
# handles write-back with reduction-splitting
|
||||||
|
if SPLIT_K == 1:
|
||||||
|
tl.store(C, acc, mask=mask)
|
||||||
|
else:
|
||||||
|
tl.atomic_add(C, acc, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
|
||||||
|
divfactor = 1. / (127. * 127.)
|
||||||
|
|
||||||
|
has_bias = 0 if bias is None else 1
|
||||||
|
|
||||||
|
device = a.device
|
||||||
|
# handle non-contiguous inputs if necessary
|
||||||
|
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||||
|
a = a.contiguous()
|
||||||
|
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||||
|
b = b.contiguous()
|
||||||
|
# checks constraints
|
||||||
|
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||||
|
M, K = a.shape
|
||||||
|
_, N = b.shape
|
||||||
|
# allocates output
|
||||||
|
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||||
|
# accumulator types
|
||||||
|
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||||
|
# launch int8_matmul_rowwise_dequantize kernel
|
||||||
|
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||||
|
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||||
|
a.stride(0), a.stride(1),
|
||||||
|
b.stride(0), b.stride(1),
|
||||||
|
c.stride(0), c.stride(1),
|
||||||
|
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||||
|
return c
|
68
bitsandbytes/triton/quantize_columnwise_and_transpose.py
Normal file
68
bitsandbytes/triton/quantize_columnwise_and_transpose.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# This kernel does fused columnwise quantization and transpose.
|
||||||
|
|
||||||
|
# TODO: autotune this better.
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config({}, num_stages=1),
|
||||||
|
triton.Config({}, num_stages=2),
|
||||||
|
triton.Config({}, num_stages=4),
|
||||||
|
triton.Config({}, num_stages=8),
|
||||||
|
triton.Config({}, num_stages=16),
|
||||||
|
triton.Config({}, num_stages=1, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=2, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=4, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=8, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=16, num_warps=8),
|
||||||
|
triton.Config({}, num_warps=1),
|
||||||
|
triton.Config({}, num_warps=2),
|
||||||
|
triton.Config({}, num_warps=4),
|
||||||
|
triton.Config({}, num_warps=8),
|
||||||
|
],
|
||||||
|
key=['n_elements']
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _quantize_columnwise_and_transpose(
|
||||||
|
x_ptr,
|
||||||
|
output_ptr,
|
||||||
|
output_maxs,
|
||||||
|
n_elements,
|
||||||
|
M : tl.constexpr, N : tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
P2: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
block_start = pid
|
||||||
|
p2_arange = tl.arange(0, P2)
|
||||||
|
p2_arange_mask = p2_arange < M
|
||||||
|
arange = p2_arange * N
|
||||||
|
offsets = block_start + arange
|
||||||
|
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
|
||||||
|
abs_x = tl.abs(x)
|
||||||
|
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
|
||||||
|
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||||
|
|
||||||
|
new_start = pid * M
|
||||||
|
new_offsets = new_start + p2_arange
|
||||||
|
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
|
||||||
|
tl.store(output_maxs + pid, max_val)
|
||||||
|
|
||||||
|
def quantize_columnwise_and_transpose(x: torch.Tensor):
|
||||||
|
M, N = x.shape
|
||||||
|
output = torch.empty(N, M, device=x.device, dtype=torch.int8)
|
||||||
|
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
|
||||||
|
|
||||||
|
P2 = int(2 ** (math.ceil(math.log2(M))))
|
||||||
|
|
||||||
|
assert x.is_cuda and output.is_cuda
|
||||||
|
n_elements = output.numel()
|
||||||
|
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||||
|
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
|
||||||
|
return output, output_maxs
|
||||||
|
|
100
bitsandbytes/triton/quantize_global.py
Normal file
100
bitsandbytes/triton/quantize_global.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# global quantize
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
|
||||||
|
|
||||||
|
],
|
||||||
|
key=['n_elements']
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _quantize_global(
|
||||||
|
x_ptr,
|
||||||
|
absmax_inv_ptr,
|
||||||
|
output_ptr,
|
||||||
|
n_elements,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
block_start = pid * BLOCK_SIZE
|
||||||
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < n_elements
|
||||||
|
x = tl.load(x_ptr + offsets, mask=mask)
|
||||||
|
absmax_inv = tl.load(absmax_inv_ptr)
|
||||||
|
output = tl.libdevice.llrint(127. * (x * absmax_inv))
|
||||||
|
tl.store(output_ptr + offsets, output, mask=mask)
|
||||||
|
|
||||||
|
def quantize_global(x: torch.Tensor):
|
||||||
|
absmax = x.abs().max().unsqueeze(0)
|
||||||
|
absmax_inv = 1./ absmax
|
||||||
|
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
|
||||||
|
assert x.is_cuda and output.is_cuda
|
||||||
|
n_elements = output.numel()
|
||||||
|
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||||
|
_quantize_global[grid](x, absmax_inv, output, n_elements)
|
||||||
|
return output, absmax
|
||||||
|
|
||||||
|
|
||||||
|
# global quantize and transpose
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||||
|
|
||||||
|
# ...
|
||||||
|
],
|
||||||
|
key=['M', 'N']
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
|
||||||
|
BLOCK_M : tl.constexpr,
|
||||||
|
BLOCK_N : tl.constexpr,
|
||||||
|
GROUP_M : tl.constexpr):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||||
|
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||||
|
|
||||||
|
width = GROUP_M * grid_n
|
||||||
|
group_id = pid // width
|
||||||
|
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||||
|
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||||
|
pid_n = (pid % width) // group_size
|
||||||
|
|
||||||
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
|
||||||
|
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||||
|
a = tl.load(A, mask=mask)
|
||||||
|
absmax_inv = tl.load(absmax_inv_ptr)
|
||||||
|
|
||||||
|
# rematerialize to save registers
|
||||||
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
|
||||||
|
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||||
|
|
||||||
|
output = tl.libdevice.llrint(127. * (a * absmax_inv))
|
||||||
|
|
||||||
|
tl.store(B, output, mask=mask)
|
||||||
|
|
||||||
|
def quantize_global_transpose(input):
|
||||||
|
absmax = input.abs().max().unsqueeze(0)
|
||||||
|
absmax_inv = 1./ absmax
|
||||||
|
M, N = input.shape
|
||||||
|
out = torch.empty(N, M, device='cuda', dtype=torch.int8)
|
||||||
|
|
||||||
|
assert out.size(0) == N and out.size(1) == M
|
||||||
|
assert input.stride(0) == 1 or input.stride(1) == 1
|
||||||
|
assert out.stride(0) == 1 or out.stride(1) == 1
|
||||||
|
|
||||||
|
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
|
||||||
|
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
|
||||||
|
return out, absmax
|
||||||
|
|
61
bitsandbytes/triton/quantize_rowwise.py
Normal file
61
bitsandbytes/triton/quantize_rowwise.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# rowwise quantize
|
||||||
|
|
||||||
|
# TODO: autotune this better.
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config({}, num_stages=1, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=2, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=4, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=8, num_warps=8),
|
||||||
|
triton.Config({}, num_stages=1),
|
||||||
|
triton.Config({}, num_stages=2),
|
||||||
|
triton.Config({}, num_stages=4),
|
||||||
|
triton.Config({}, num_stages=8),
|
||||||
|
triton.Config({}, num_warps=1),
|
||||||
|
triton.Config({}, num_warps=2),
|
||||||
|
triton.Config({}, num_warps=4),
|
||||||
|
triton.Config({}, num_warps=8),
|
||||||
|
],
|
||||||
|
key=['n_elements']
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _quantize_rowwise(
|
||||||
|
x_ptr,
|
||||||
|
output_ptr,
|
||||||
|
output_maxs,
|
||||||
|
n_elements,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
P2: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
block_start = pid * BLOCK_SIZE
|
||||||
|
arange = tl.arange(0, P2)
|
||||||
|
offsets = block_start + arange
|
||||||
|
row_mask = arange < BLOCK_SIZE
|
||||||
|
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||||
|
|
||||||
|
abs_x = tl.abs(x)
|
||||||
|
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
|
||||||
|
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||||
|
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||||
|
tl.store(output_maxs + pid, max_val)
|
||||||
|
|
||||||
|
def quantize_rowwise(x: torch.Tensor):
|
||||||
|
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
|
||||||
|
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
|
||||||
|
|
||||||
|
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||||
|
|
||||||
|
assert x.is_cuda and output.is_cuda
|
||||||
|
n_elements = output.numel()
|
||||||
|
grid = lambda meta: (x.shape[0],)
|
||||||
|
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||||
|
return output, output_maxs
|
||||||
|
|
Loading…
Reference in New Issue
Block a user