diff --git a/bitsandbytes/research/__init__.py b/bitsandbytes/research/__init__.py new file mode 100644 index 0000000..f5ab510 --- /dev/null +++ b/bitsandbytes/research/__init__.py @@ -0,0 +1,7 @@ + +from .autograd._functions import ( + matmul_fp8, + switchback_bnb, + matmul_fp8_global, + matmul_fp8_mixed, +) diff --git a/bitsandbytes/research/autograd/__init__.py b/bitsandbytes/research/autograd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py new file mode 100644 index 0000000..b0a098d --- /dev/null +++ b/bitsandbytes/research/autograd/_functions.py @@ -0,0 +1,493 @@ +import operator +import warnings +from dataclasses import dataclass +from functools import reduce # Required in Python 3 + +import torch + +import bitsandbytes.functional as F + +from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler + + +# math.prod not compatible with python < 3.8 +def prod(iterable): + return reduce(operator.mul, iterable, 1) + +tensor = torch.Tensor + +class MatMulFP8(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + + B_shape = B.shape + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + # 1. Dequantize + # 2. MatmulnN + cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz) + fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype) + + cB, state = F.quantize(B.float(), code=fw_code) + fp8B = F.dequantize(cB, state).to(B.dtype) + + output = torch.matmul(fp8A, fp8B) + + # output is half + + # 3. Save state + ctx.fw_code = fw_code + ctx.bw_code = bw_code + ctx.bsz = bsz + ctx.bsz2 = bsz2 + ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype + + if any(ctx.needs_input_grad[:2]): + # NOTE: we send back A, and re-quant. + ctx.tensors = (A, fp8B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None + + req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad + A, B = ctx.tensors + + grad_A, grad_B = None, None + + cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2) + fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype) + + cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) + fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) + + # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') + # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose + # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) + + # not supported by PyTorch. 
TODO: create work-around + if req_gradA: + grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) + + if req_gradB: + if len(A.shape) == 3: + At = A.transpose(2, 1).contiguous() + else: + At = A.transpose(1, 0).contiguous() + cA, state = F.quantize(At.float(), code=ctx.fw_code) + fp8At = F.dequantize(cA, state).to(A.dtype) + grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype) + + return grad_A, grad_B, None, None, None, None, None + +class MatMulFP8Mixed(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + + B_shape = B.shape + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + # 1. Dequantize + # 2. MatmulnN + cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz) + fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype) + + cB, state = F.quantize(B.float(), code=fw_code) + fp8B = F.dequantize(cB, state).to(B.dtype) + + output = torch.matmul(fp8A, fp8B) + + # output is half + + # 3. Save state + ctx.fw_code = fw_code + ctx.bw_code = bw_code + ctx.bsz = bsz + ctx.bsz2 = bsz2 + ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype + + if any(ctx.needs_input_grad[:2]): + # NOTE: we send back A, and re-quant. + ctx.tensors = (A, fp8B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None + + req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad + A, B = ctx.tensors + + grad_A, grad_B = None, None + + # TODO: Fix blocksize to be output_dim + cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2) + fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype) + + # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) + # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) + + # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') + # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose + # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) + + # not supported by PyTorch. 
TODO: create work-around + if req_gradA: + grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) + + if req_gradB: + At = A.transpose(2, 1).contiguous() + # cA, state = F.quantize(At.float(), code=ctx.fw_code) + # fp8At = F.dequantize(cA, state).to(A.dtype) + grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype) + + return grad_A, grad_B, None, None, None, None, None + + +class MatMulFP8Global(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + + B_shape = B.shape + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + # 1. Dequantize + # 2. MatmulnN + cA, state = F.quantize(A.float(), code=fw_code) + fp8A = F.dequantize(cA, state).to(A.dtype) + + cB, state = F.quantize(B.float(), code=fw_code) + fp8B = F.dequantize(cB, state).to(B.dtype) + + output = torch.matmul(fp8A, fp8B) + + # output is half + + # 3. Save state + ctx.fw_code = fw_code + ctx.bw_code = bw_code + ctx.bsz = bsz + ctx.bsz2 = bsz2 + ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype + + if any(ctx.needs_input_grad[:2]): + # NOTE: we send back A, and re-quant. + ctx.tensors = (A, fp8B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None + + req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad + A, B = ctx.tensors + + grad_A, grad_B = None, None + + # TODO: Fix blocksize to be output_dim + cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code) + fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype) + + # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) + # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) + + # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') + # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose + # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) + + # not supported by PyTorch. 
TODO: create work-around + if req_gradA: + grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) + + if req_gradB: + At = A.transpose(2, 1).contiguous() + cA, state = F.quantize(At.float(), code=ctx.fw_code) + fp8At = F.dequantize(cA, state).to(A.dtype) + grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype) + + return grad_A, grad_B, None, None, None, None, None + + +class MatMul8bitMixed(torch.autograd.Function): + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): + # default to pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) + + # 1. Quantize A + # 2. Quantize B + # 3. Matmul + # 4. Mixed-precision decomposition matmul + # 5. Save state + formatB = state.formatB + input_shape = A.shape + if state.outlier_pool is None: + state.outlier_pool = GlobalOutlierPooler.get_instance() + + # Cast A to fp16 + if A.dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") + + # 1. Quantize A + if len(A.shape) == 3: + A = A.view(-1, A.shape[-1]).contiguous() + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A.to(torch.float16), threshold=state.threshold + ) + + if state.threshold > 0.0 and coo_tensorA is not None: + if state.has_fp16_weights: + idx = torch.unique(coo_tensorA.colidx).long() + CA[:, idx] = 0 + CAt[:, idx] = 0 + subA = A[:, idx] + state.subB = B[:, idx].t().contiguous() + state.idx = idx + else: + if state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + else: + #print('A shape', A.shape) + if not state.has_fp16_weights and state.CxB is None: + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + subA = None + + # 2. 
Quantize B + if state.has_fp16_weights: + #print('B shape', B.shape) + has_grad = True if (getattr(B, "grad", None) is not None) else False + is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) + if is_transposed: + B = B.contiguous() + + if (state.is_training and not has_grad) or state.CxB is None: + state.reset_grads() + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B.to(torch.float16)) + state.CxB, state.SB = F.transform(CB, to_order=formatB) + else: + has_grad = False + + if coo_tensorA is not None and not state.has_fp16_weights: + # extract outliers + + outlier_idx = torch.unique(coo_tensorA.colidx) + state.idx = outlier_idx + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + state.subB = ( + (outliers * state.SCB.view(-1, 1) / 127.0) + .t() + .contiguous() + .to(A.dtype) + ) + CA[:, state.idx.long()] = 0 + CAt[:, state.idx.long()] = 0 + subA = A[:, state.idx.long()] + + shapeB = state.SB[0] + + if len(input_shape) == 3: + output_shape = (input_shape[0], input_shape[1], shapeB[0]) + else: + output_shape = (input_shape[0], shapeB[0]) + + # 3. Matmul + C32A, SA = F.transform(CA, "col32") + out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + # we apply the fused bias here + + if bias is None or bias.dtype == torch.float16: + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A.dtype) + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A.dtype).add_(bias) + + # 4. Mixed-precision decomposition matmul + if coo_tensorA is not None and subA is not None: + output += torch.matmul(subA, state.subB) + + # 5. 
Save state + ctx.state = state + + ctx.formatB = formatB + ctx.grad_shape = input_shape + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + + if any(ctx.needs_input_grad[:2]): + ctx.tensors = (CAt, subA, A) + ctx.tensor_states = (SCAt, state.idx) + else: + ctx.tensors = [None, None, None] + ctx.tensor_states = (None, None) + ctx.save_for_backward(None, None) + + + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + return clone_func(output.view(output_shape)) + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad + CAt, subA, A = ctx.tensors + SCAt, idx = ctx.tensor_states + formatB = ctx.formatB + state = ctx.state + grad_A = grad_B = grad_bias = None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # Cast grad_output to fp16 + if len(grad_output.shape) == 3: + grad_output = grad_output.reshape( + -1, grad_output.shape[-1] + ).contiguous() + + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) + + if req_gradB: + # print('back A shape', A.shape) + # print('grad output t shape', grad_output.t().shape) + grad_B = torch.matmul(grad_output.t(), A) + + if req_gradA: + if state.CBt is not None: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) + # print('back B shape', state.CxBt.shape) + # print('back grad shape', C32grad.shape) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) + + elif state.CB is not None: + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. 
/ 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + else: + raise Exception('State must contain either CBt or CB matrix for backward') + + return grad_A, grad_B, None, grad_bias, None + +def get_block_sizes(input_matrix, weight_matrix): + input_features = input_matrix.shape[-1] + output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) + array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + bsz, bsz2 = 1024, 1024 + for i, k in enumerate(array): + if input_features > array[i + 1]: + bsz = k + break + for i, k in enumerate(array): + if output_features > array[i + 1]: + bsz2 = k + break + + return bsz, bsz2 + + +def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + +def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + +def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + + +def switchback_bnb( + A: tensor, + B: tensor, + out: tensor = None, + state: MatmulLtState = None, + threshold=0.0, + bias=None +): + state = state or MatmulLtState() + if threshold > 0.0: + state.threshold = threshold + return MatMul8bitMixed.apply(A, B, out, bias, state) diff --git a/bitsandbytes/triton/__init__.py b/bitsandbytes/triton/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py new file mode 100644 index 0000000..7e31483 --- /dev/null +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -0,0 +1,58 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# rowwise quantize + +# TODO: autotune this better. 
+@triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _dequantize_rowwise( + x_ptr, + state_x, + output_ptr, + inv_127, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + max_val = tl.load(state_x + pid) + output = max_val * x * inv_127 + tl.store(output_ptr + offsets, output, mask=row_mask) + + +def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py b/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py new file mode 100644 index 0000000..69d4b0c --- /dev/null +++ b/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py @@ -0,0 +1,158 @@ +import torch + +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +# This is a matmul kernel based on triton.ops.matmul +# It is modified to support rowwise quantized input and global quantized weight +# It's purpose is fused matmul then dequantize +# It does support bias. 
+ +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better 
L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr) + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + # conditionally add bias + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): + device = a.device + divfactor = 1. / (127. * 127.) + has_bias = 0 if bias is None else 1 + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch int8_matmul_mixed_dequantize kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py new file mode 100644 index 0000000..4af054b --- /dev/null +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -0,0 +1,159 @@ +import torch + +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# This is a matmul kernel based on triton.ops.matmul +# It is modified to support rowwise quantized input and columnwise quantized weight +# It's purpose is fused matmul then dequantize +# It does support bias. 
+ +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 
performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr + rbn)[None, :] + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): + divfactor = 1. / (127. * 127.) + + has_bias = 0 if bias is None else 1 + + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch int8_matmul_rowwise_dequantize kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py new file mode 100644 index 0000000..4e53475 --- /dev/null +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -0,0 +1,68 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# This kernel does fused columnwise quantization and transpose. + +# TODO: autotune this better. 
+@triton.autotune( + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_columnwise_and_transpose( + x_ptr, + output_ptr, + output_maxs, + n_elements, + M : tl.constexpr, N : tl.constexpr, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid + p2_arange = tl.arange(0, P2) + p2_arange_mask = p2_arange < M + arange = p2_arange * N + offsets = block_start + arange + x = tl.load(x_ptr + offsets, mask=p2_arange_mask) + abs_x = tl.abs(x) + max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + + new_start = pid * M + new_offsets = new_start + p2_arange + tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_columnwise_and_transpose(x: torch.Tensor): + M, N = x.shape + output = torch.empty(N, M, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(M)))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) + return output, output_maxs + diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py new file mode 100644 index 0000000..229721c --- /dev/null +++ b/bitsandbytes/triton/quantize_global.py @@ -0,0 +1,100 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# global quantize +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), + triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), + + ], + key=['n_elements'] +) +@triton.jit +def _quantize_global( + x_ptr, + absmax_inv_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + output = tl.libdevice.llrint(127. * (x * absmax_inv)) + tl.store(output_ptr + offsets, output, mask=mask) + +def quantize_global(x: torch.Tensor): + absmax = x.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_global[grid](x, absmax_inv, output, n_elements) + return output, absmax + + +# global quantize and transpose +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + + # ... 
+ ], + key=['M', 'N'] +) +@triton.jit +def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, + BLOCK_M : tl.constexpr, + BLOCK_N : tl.constexpr, + GROUP_M : tl.constexpr): + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // group_size + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) + mask = (rm < M)[:, None] & (rn < N)[None, :] + a = tl.load(A, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + + # rematerialize to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + + output = tl.libdevice.llrint(127. * (a * absmax_inv)) + + tl.store(B, output, mask=mask) + +def quantize_global_transpose(input): + absmax = input.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + M, N = input.shape + out = torch.empty(N, M, device='cuda', dtype=torch.int8) + + assert out.size(0) == N and out.size(1) == M + assert input.stride(0) == 1 or input.stride(1) == 1 + assert out.stride(0) == 1 or out.stride(1) == 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) + return out, absmax + diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py new file mode 100644 index 0000000..d956647 --- /dev/null +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -0,0 +1,61 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# rowwise quantize + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_rowwise( + x_ptr, + output_ptr, + output_maxs, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. 
* (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_rowwise(x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs +
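Usage note (not part of the patch): below is a minimal sketch of how the row-wise int8 path added here can be exercised end to end. It only calls functions defined in this diff; the tensor names, shapes, and the choice to quantize the weight row-wise and pass it transposed are illustrative assumptions, not requirements of the patch.

```python
import torch

from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize

# Hypothetical shapes: 32 activation rows, 512 input features, 1024 output features.
x = torch.randn(32, 512, device='cuda', dtype=torch.float16)    # activations
w = torch.randn(1024, 512, device='cuda', dtype=torch.float16)  # weight (out_features, in_features)
bias = torch.randn(1024, device='cuda', dtype=torch.float16)

# Row-wise int8 quantization: returns the int8 tensor plus one float16 absmax per row.
x_int8, state_x = quantize_rowwise(x)   # state_x: (32,)
w_int8, state_w = quantize_rowwise(w)   # state_w: (1024,), one scale per output feature

# Fused int8 matmul + dequantize (+ optional bias). The weight is passed transposed so that
# a.shape[1] == b.shape[0]; the kernel rescales the int32 accumulator by state_x, state_w
# and 1/(127*127) and writes a float16 result.
out = int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x, state_w, bias)
print(out.shape, out.dtype)  # torch.Size([32, 1024]) torch.float16
```

The FP8 entry points work analogously at a higher level: matmul_fp8, matmul_fp8_global and matmul_fp8_mixed take fw_code/bw_code codebook tensors (consumed by F.quantize / F.quantize_blockwise inside the autograd functions) and, when bsz/bsz2 are left at -1, pick block sizes via get_block_sizes from the input and output feature dimensions. switchback_bnb takes the same MatmulLtState-based arguments as the existing 8-bit matmul and dispatches to MatMul8bitMixed.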