Added is_available_triton guard.
This commit is contained in:
parent
7140c01405
commit
c3d87e4435
|
@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function):
|
|||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
class MatMul8bitMixed(torch.autograd.Function):
|
||||
class SwitchBackBnb(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
|
||||
# default to pytorch behavior if inputs are empty
|
||||
|
@ -408,4 +408,4 @@ def switchback_bnb(
|
|||
state = state or MatmulLtState()
|
||||
if threshold > 0.0:
|
||||
state.threshold = threshold
|
||||
return MatMul8bitMixed.apply(A, B, out, bias, state)
|
||||
return SwitchBackBnb.apply(A, B, out, bias, state)
|
||||
|
|
|
@ -1,58 +1,64 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
# rowwise quantize
|
||||
if not is_triton_available():
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _dequantize_rowwise(
|
||||
x_ptr,
|
||||
state_x,
|
||||
output_ptr,
|
||||
inv_127,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
arange = tl.arange(0, P2)
|
||||
offsets = block_start + arange
|
||||
row_mask = arange < BLOCK_SIZE
|
||||
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||
max_val = tl.load(state_x + pid)
|
||||
output = max_val * x * inv_127
|
||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
|
||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
|
||||
# rowwise quantize
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _dequantize_rowwise(
|
||||
x_ptr,
|
||||
state_x,
|
||||
output_ptr,
|
||||
inv_127,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
arange = tl.arange(0, P2)
|
||||
offsets = block_start + arange
|
||||
row_mask = arange < BLOCK_SIZE
|
||||
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||
max_val = tl.load(state_x + pid)
|
||||
output = max_val * x * inv_127
|
||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (x.shape[0],)
|
||||
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||
return output
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
|
||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (x.shape[0],)
|
||||
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||
return output
|
||||
|
|
|
@ -1,158 +1,163 @@
|
|||
import torch
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
if not is_triton_available():
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and global quantized weight
|
||||
# It's purpose is fused matmul then dequantize
|
||||
# It does support bias.
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and global quantized weight
|
||||
# It's purpose is fused matmul then dequantize
|
||||
# It does support bias.
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
w_factor = tl.load(state_w_ptr)
|
||||
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||
w_factor = tl.load(state_w_ptr)
|
||||
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||
|
||||
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
|
||||
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
|
||||
# conditionally add bias
|
||||
if has_bias:
|
||||
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||
acc = acc + bias[None, :]
|
||||
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
|
||||
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
|
||||
# conditionally add bias
|
||||
if has_bias:
|
||||
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||
acc = acc + bias[None, :]
|
||||
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
|
||||
device = a.device
|
||||
divfactor = 1. / (127. * 127.)
|
||||
has_bias = 0 if bias is None else 1
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch int8_matmul_mixed_dequantize kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
|
||||
device = a.device
|
||||
divfactor = 1. / (127. * 127.)
|
||||
has_bias = 0 if bias is None else 1
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch int8_matmul_mixed_dequantize kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
|
||||
|
|
|
@ -1,159 +1,164 @@
|
|||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and columnwise quantized weight
|
||||
# It's purpose is fused matmul then dequantize
|
||||
# It does support bias.
|
||||
if not is_triton_available():
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
|
||||
else:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and columnwise quantized weight
|
||||
# It's purpose is fused matmul then dequantize
|
||||
# It does support bias.
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
w_factor = tl.load(state_w_ptr + rbn)[None, :]
|
||||
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||
w_factor = tl.load(state_w_ptr + rbn)[None, :]
|
||||
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||
|
||||
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
|
||||
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
|
||||
if has_bias:
|
||||
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||
acc = acc + bias[None, :]
|
||||
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
|
||||
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
|
||||
if has_bias:
|
||||
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||
acc = acc + bias[None, :]
|
||||
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
|
||||
divfactor = 1. / (127. * 127.)
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
|
||||
divfactor = 1. / (127. * 127.)
|
||||
|
||||
has_bias = 0 if bias is None else 1
|
||||
has_bias = 0 if bias is None else 1
|
||||
|
||||
device = a.device
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch int8_matmul_rowwise_dequantize kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
|
||||
device = a.device
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch int8_matmul_rowwise_dequantize kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
|
||||
|
|
|
@ -1,68 +1,74 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
# This kernel does fused columnwise quantization and transpose.
|
||||
if not is_triton_available():
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_stages=16),
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=16, num_warps=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_columnwise_and_transpose(
|
||||
x_ptr,
|
||||
output_ptr,
|
||||
output_maxs,
|
||||
n_elements,
|
||||
M : tl.constexpr, N : tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid
|
||||
p2_arange = tl.arange(0, P2)
|
||||
p2_arange_mask = p2_arange < M
|
||||
arange = p2_arange * N
|
||||
offsets = block_start + arange
|
||||
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
|
||||
abs_x = tl.abs(x)
|
||||
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
|
||||
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
new_start = pid * M
|
||||
new_offsets = new_start + p2_arange
|
||||
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
|
||||
tl.store(output_maxs + pid, max_val)
|
||||
# This kernel does fused columnwise quantization and transpose.
|
||||
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor):
|
||||
M, N = x.shape
|
||||
output = torch.empty(N, M, device=x.device, dtype=torch.int8)
|
||||
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_stages=16),
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=16, num_warps=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_columnwise_and_transpose(
|
||||
x_ptr,
|
||||
output_ptr,
|
||||
output_maxs,
|
||||
n_elements,
|
||||
M : tl.constexpr, N : tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid
|
||||
p2_arange = tl.arange(0, P2)
|
||||
p2_arange_mask = p2_arange < M
|
||||
arange = p2_arange * N
|
||||
offsets = block_start + arange
|
||||
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
|
||||
abs_x = tl.abs(x)
|
||||
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
|
||||
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(M))))
|
||||
new_start = pid * M
|
||||
new_offsets = new_start + p2_arange
|
||||
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
|
||||
tl.store(output_maxs + pid, max_val)
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
|
||||
return output, output_maxs
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor):
|
||||
M, N = x.shape
|
||||
output = torch.empty(N, M, device=x.device, dtype=torch.int8)
|
||||
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(M))))
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
|
||||
return output, output_maxs
|
||||
|
||||
|
|
|
@ -1,100 +1,107 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
# global quantize
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
|
||||
if not is_triton_available():
|
||||
def quantize_global_transpose(input): return None
|
||||
def quantize_global(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_global(
|
||||
x_ptr,
|
||||
absmax_inv_ptr,
|
||||
output_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
absmax_inv = tl.load(absmax_inv_ptr)
|
||||
output = tl.libdevice.llrint(127. * (x * absmax_inv))
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
def quantize_global(x: torch.Tensor):
|
||||
absmax = x.abs().max().unsqueeze(0)
|
||||
absmax_inv = 1./ absmax
|
||||
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_quantize_global[grid](x, absmax_inv, output, n_elements)
|
||||
return output, absmax
|
||||
# global quantize
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
|
||||
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_global(
|
||||
x_ptr,
|
||||
absmax_inv_ptr,
|
||||
output_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
absmax_inv = tl.load(absmax_inv_ptr)
|
||||
output = tl.libdevice.llrint(127. * (x * absmax_inv))
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
def quantize_global(x: torch.Tensor):
|
||||
absmax = x.abs().max().unsqueeze(0)
|
||||
absmax_inv = 1./ absmax
|
||||
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_quantize_global[grid](x, absmax_inv, output, n_elements)
|
||||
return output, absmax
|
||||
|
||||
|
||||
# global quantize and transpose
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||
# global quantize and transpose
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||
|
||||
# ...
|
||||
],
|
||||
key=['M', 'N']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
|
||||
BLOCK_M : tl.constexpr,
|
||||
BLOCK_N : tl.constexpr,
|
||||
GROUP_M : tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // group_size
|
||||
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
a = tl.load(A, mask=mask)
|
||||
absmax_inv = tl.load(absmax_inv_ptr)
|
||||
|
||||
# rematerialize to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# ...
|
||||
],
|
||||
key=['M', 'N']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
|
||||
BLOCK_M : tl.constexpr,
|
||||
BLOCK_N : tl.constexpr,
|
||||
GROUP_M : tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // group_size
|
||||
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
a = tl.load(A, mask=mask)
|
||||
absmax_inv = tl.load(absmax_inv_ptr)
|
||||
|
||||
# rematerialize to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
|
||||
output = tl.libdevice.llrint(127. * (a * absmax_inv))
|
||||
output = tl.libdevice.llrint(127. * (a * absmax_inv))
|
||||
|
||||
tl.store(B, output, mask=mask)
|
||||
tl.store(B, output, mask=mask)
|
||||
|
||||
def quantize_global_transpose(input):
|
||||
absmax = input.abs().max().unsqueeze(0)
|
||||
absmax_inv = 1./ absmax
|
||||
M, N = input.shape
|
||||
out = torch.empty(N, M, device='cuda', dtype=torch.int8)
|
||||
|
||||
assert out.size(0) == N and out.size(1) == M
|
||||
assert input.stride(0) == 1 or input.stride(1) == 1
|
||||
assert out.stride(0) == 1 or out.stride(1) == 1
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
|
||||
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
|
||||
return out, absmax
|
||||
def quantize_global_transpose(input):
|
||||
absmax = input.abs().max().unsqueeze(0)
|
||||
absmax_inv = 1./ absmax
|
||||
M, N = input.shape
|
||||
out = torch.empty(N, M, device='cuda', dtype=torch.int8)
|
||||
|
||||
assert out.size(0) == N and out.size(1) == M
|
||||
assert input.stride(0) == 1 or input.stride(1) == 1
|
||||
assert out.stride(0) == 1 or out.stride(1) == 1
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
|
||||
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
|
||||
return out, absmax
|
||||
|
||||
|
|
|
@ -1,61 +1,68 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# rowwise quantize
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_rowwise(
|
||||
x_ptr,
|
||||
output_ptr,
|
||||
output_maxs,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
arange = tl.arange(0, P2)
|
||||
offsets = block_start + arange
|
||||
row_mask = arange < BLOCK_SIZE
|
||||
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||
|
||||
abs_x = tl.abs(x)
|
||||
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
|
||||
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||
tl.store(output_maxs + pid, max_val)
|
||||
if not is_triton_available():
|
||||
def quantize_rowwise(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
def quantize_rowwise(x: torch.Tensor):
|
||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
|
||||
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||
# rowwise quantize
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (x.shape[0],)
|
||||
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||
return output, output_maxs
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_rowwise(
|
||||
x_ptr,
|
||||
output_ptr,
|
||||
output_maxs,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
arange = tl.arange(0, P2)
|
||||
offsets = block_start + arange
|
||||
row_mask = arange < BLOCK_SIZE
|
||||
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||
|
||||
abs_x = tl.abs(x)
|
||||
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
|
||||
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||
tl.store(output_maxs + pid, max_val)
|
||||
|
||||
def quantize_rowwise(x: torch.Tensor):
|
||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
|
||||
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (x.shape[0],)
|
||||
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||
return output, output_maxs
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user