Added is_available_triton guard.
This commit is contained in:
parent
7140c01405
commit
c3d87e4435
|
@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function):
|
||||||
return grad_A, grad_B, None, None, None, None, None
|
return grad_A, grad_B, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class MatMul8bitMixed(torch.autograd.Function):
|
class SwitchBackBnb(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
|
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
|
||||||
# default to pytorch behavior if inputs are empty
|
# default to pytorch behavior if inputs are empty
|
||||||
|
@ -408,4 +408,4 @@ def switchback_bnb(
|
||||||
state = state or MatmulLtState()
|
state = state or MatmulLtState()
|
||||||
if threshold > 0.0:
|
if threshold > 0.0:
|
||||||
state.threshold = threshold
|
state.threshold = threshold
|
||||||
return MatMul8bitMixed.apply(A, B, out, bias, state)
|
return SwitchBackBnb.apply(A, B, out, bias, state)
|
||||||
|
|
|
@ -1,14 +1,20 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
import triton
|
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||||
import triton.language as tl
|
|
||||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
|
||||||
|
|
||||||
# rowwise quantize
|
if not is_triton_available():
|
||||||
|
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
|
||||||
|
else:
|
||||||
|
|
||||||
# TODO: autotune this better.
|
import triton
|
||||||
@triton.autotune(
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# rowwise quantize
|
||||||
|
|
||||||
|
# TODO: autotune this better.
|
||||||
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({}, num_stages=1, num_warps=8),
|
triton.Config({}, num_stages=1, num_warps=8),
|
||||||
triton.Config({}, num_stages=2, num_warps=8),
|
triton.Config({}, num_stages=2, num_warps=8),
|
||||||
|
@ -24,9 +30,9 @@ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_tim
|
||||||
triton.Config({}, num_warps=8),
|
triton.Config({}, num_warps=8),
|
||||||
],
|
],
|
||||||
key=['n_elements']
|
key=['n_elements']
|
||||||
)
|
)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _dequantize_rowwise(
|
def _dequantize_rowwise(
|
||||||
x_ptr,
|
x_ptr,
|
||||||
state_x,
|
state_x,
|
||||||
output_ptr,
|
output_ptr,
|
||||||
|
@ -34,7 +40,7 @@ def _dequantize_rowwise(
|
||||||
n_elements,
|
n_elements,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
P2: tl.constexpr,
|
P2: tl.constexpr,
|
||||||
):
|
):
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
block_start = pid * BLOCK_SIZE
|
block_start = pid * BLOCK_SIZE
|
||||||
arange = tl.arange(0, P2)
|
arange = tl.arange(0, P2)
|
||||||
|
@ -46,7 +52,7 @@ def _dequantize_rowwise(
|
||||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||||
|
|
||||||
|
|
||||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
|
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
|
||||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
|
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
|
||||||
|
|
||||||
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||||
|
|
|
@ -1,19 +1,24 @@
|
||||||
import torch
|
import torch
|
||||||
|
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||||
|
|
||||||
import triton
|
if not is_triton_available():
|
||||||
import triton.language as tl
|
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
|
||||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
else:
|
||||||
|
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
|
||||||
# This is a matmul kernel based on triton.ops.matmul
|
# This is a matmul kernel based on triton.ops.matmul
|
||||||
# It is modified to support rowwise quantized input and global quantized weight
|
# It is modified to support rowwise quantized input and global quantized weight
|
||||||
# It's purpose is fused matmul then dequantize
|
# It's purpose is fused matmul then dequantize
|
||||||
# It does support bias.
|
# It does support bias.
|
||||||
|
|
||||||
def init_to_zero(name):
|
def init_to_zero(name):
|
||||||
return lambda nargs: nargs[name].zero_()
|
return lambda nargs: nargs[name].zero_()
|
||||||
|
|
||||||
def get_configs_io_bound():
|
def get_configs_io_bound():
|
||||||
configs = []
|
configs = []
|
||||||
for num_stages in [2, 3, 4, 5, 6]:
|
for num_stages in [2, 3, 4, 5, 6]:
|
||||||
for block_m in [16, 32]:
|
for block_m in [16, 32]:
|
||||||
|
@ -30,7 +35,7 @@ def get_configs_io_bound():
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
# basic configs for compute-bound matmuls
|
# basic configs for compute-bound matmuls
|
||||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
@ -59,12 +64,12 @@ def get_configs_io_bound():
|
||||||
'perf_model': estimate_matmul_time,
|
'perf_model': estimate_matmul_time,
|
||||||
'top_k': 10
|
'top_k': 10
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@triton.heuristics({
|
@triton.heuristics({
|
||||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||||
})
|
})
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
|
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
|
||||||
stride_am, stride_ak,
|
stride_am, stride_ak,
|
||||||
stride_bk, stride_bn,
|
stride_bk, stride_bn,
|
||||||
stride_cm, stride_cn,
|
stride_cm, stride_cn,
|
||||||
|
@ -131,7 +136,7 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N,
|
||||||
tl.atomic_add(C, acc, mask=mask)
|
tl.atomic_add(C, acc, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
|
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
|
||||||
device = a.device
|
device = a.device
|
||||||
divfactor = 1. / (127. * 127.)
|
divfactor = 1. / (127. * 127.)
|
||||||
has_bias = 0 if bias is None else 1
|
has_bias = 0 if bias is None else 1
|
||||||
|
|
|
@ -1,19 +1,24 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import triton
|
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||||
import triton.language as tl
|
|
||||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
|
||||||
|
|
||||||
# This is a matmul kernel based on triton.ops.matmul
|
if not is_triton_available():
|
||||||
# It is modified to support rowwise quantized input and columnwise quantized weight
|
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
|
||||||
# It's purpose is fused matmul then dequantize
|
else:
|
||||||
# It does support bias.
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
def init_to_zero(name):
|
# This is a matmul kernel based on triton.ops.matmul
|
||||||
|
# It is modified to support rowwise quantized input and columnwise quantized weight
|
||||||
|
# It's purpose is fused matmul then dequantize
|
||||||
|
# It does support bias.
|
||||||
|
|
||||||
|
def init_to_zero(name):
|
||||||
return lambda nargs: nargs[name].zero_()
|
return lambda nargs: nargs[name].zero_()
|
||||||
|
|
||||||
|
|
||||||
def get_configs_io_bound():
|
def get_configs_io_bound():
|
||||||
configs = []
|
configs = []
|
||||||
for num_stages in [2, 3, 4, 5, 6]:
|
for num_stages in [2, 3, 4, 5, 6]:
|
||||||
for block_m in [16, 32]:
|
for block_m in [16, 32]:
|
||||||
|
@ -30,7 +35,7 @@ def get_configs_io_bound():
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
# basic configs for compute-bound matmuls
|
# basic configs for compute-bound matmuls
|
||||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
|
@ -59,12 +64,12 @@ def get_configs_io_bound():
|
||||||
'perf_model': estimate_matmul_time,
|
'perf_model': estimate_matmul_time,
|
||||||
'top_k': 10
|
'top_k': 10
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@triton.heuristics({
|
@triton.heuristics({
|
||||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||||
})
|
})
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
|
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
|
||||||
stride_am, stride_ak,
|
stride_am, stride_ak,
|
||||||
stride_bk, stride_bn,
|
stride_bk, stride_bn,
|
||||||
stride_cm, stride_cn,
|
stride_cm, stride_cn,
|
||||||
|
@ -130,7 +135,7 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M,
|
||||||
tl.atomic_add(C, acc, mask=mask)
|
tl.atomic_add(C, acc, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
|
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
|
||||||
divfactor = 1. / (127. * 127.)
|
divfactor = 1. / (127. * 127.)
|
||||||
|
|
||||||
has_bias = 0 if bias is None else 1
|
has_bias = 0 if bias is None else 1
|
||||||
|
|
|
@ -1,14 +1,20 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
import triton
|
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||||
import triton.language as tl
|
|
||||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
|
||||||
|
|
||||||
# This kernel does fused columnwise quantization and transpose.
|
if not is_triton_available():
|
||||||
|
def quantize_columnwise_and_transpose(x: torch.Tensor): return None
|
||||||
|
else:
|
||||||
|
|
||||||
# TODO: autotune this better.
|
import triton
|
||||||
@triton.autotune(
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# This kernel does fused columnwise quantization and transpose.
|
||||||
|
|
||||||
|
# TODO: autotune this better.
|
||||||
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({}, num_stages=1),
|
triton.Config({}, num_stages=1),
|
||||||
triton.Config({}, num_stages=2),
|
triton.Config({}, num_stages=2),
|
||||||
|
@ -26,9 +32,9 @@ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_tim
|
||||||
triton.Config({}, num_warps=8),
|
triton.Config({}, num_warps=8),
|
||||||
],
|
],
|
||||||
key=['n_elements']
|
key=['n_elements']
|
||||||
)
|
)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _quantize_columnwise_and_transpose(
|
def _quantize_columnwise_and_transpose(
|
||||||
x_ptr,
|
x_ptr,
|
||||||
output_ptr,
|
output_ptr,
|
||||||
output_maxs,
|
output_maxs,
|
||||||
|
@ -36,7 +42,7 @@ def _quantize_columnwise_and_transpose(
|
||||||
M : tl.constexpr, N : tl.constexpr,
|
M : tl.constexpr, N : tl.constexpr,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
P2: tl.constexpr,
|
P2: tl.constexpr,
|
||||||
):
|
):
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
block_start = pid
|
block_start = pid
|
||||||
p2_arange = tl.arange(0, P2)
|
p2_arange = tl.arange(0, P2)
|
||||||
|
@ -53,7 +59,7 @@ def _quantize_columnwise_and_transpose(
|
||||||
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
|
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
|
||||||
tl.store(output_maxs + pid, max_val)
|
tl.store(output_maxs + pid, max_val)
|
||||||
|
|
||||||
def quantize_columnwise_and_transpose(x: torch.Tensor):
|
def quantize_columnwise_and_transpose(x: torch.Tensor):
|
||||||
M, N = x.shape
|
M, N = x.shape
|
||||||
output = torch.empty(N, M, device=x.device, dtype=torch.int8)
|
output = torch.empty(N, M, device=x.device, dtype=torch.int8)
|
||||||
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
|
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
|
||||||
|
|
|
@ -1,27 +1,34 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
import triton
|
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||||
import triton.language as tl
|
|
||||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
|
||||||
|
|
||||||
# global quantize
|
if not is_triton_available():
|
||||||
@triton.autotune(
|
def quantize_global_transpose(input): return None
|
||||||
|
def quantize_global(x: torch.Tensor): return None
|
||||||
|
else:
|
||||||
|
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# global quantize
|
||||||
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
|
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
|
||||||
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
|
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
|
||||||
|
|
||||||
],
|
],
|
||||||
key=['n_elements']
|
key=['n_elements']
|
||||||
)
|
)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _quantize_global(
|
def _quantize_global(
|
||||||
x_ptr,
|
x_ptr,
|
||||||
absmax_inv_ptr,
|
absmax_inv_ptr,
|
||||||
output_ptr,
|
output_ptr,
|
||||||
n_elements,
|
n_elements,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
block_start = pid * BLOCK_SIZE
|
block_start = pid * BLOCK_SIZE
|
||||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||||
|
@ -31,7 +38,7 @@ def _quantize_global(
|
||||||
output = tl.libdevice.llrint(127. * (x * absmax_inv))
|
output = tl.libdevice.llrint(127. * (x * absmax_inv))
|
||||||
tl.store(output_ptr + offsets, output, mask=mask)
|
tl.store(output_ptr + offsets, output, mask=mask)
|
||||||
|
|
||||||
def quantize_global(x: torch.Tensor):
|
def quantize_global(x: torch.Tensor):
|
||||||
absmax = x.abs().max().unsqueeze(0)
|
absmax = x.abs().max().unsqueeze(0)
|
||||||
absmax_inv = 1./ absmax
|
absmax_inv = 1./ absmax
|
||||||
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
|
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
|
||||||
|
@ -42,8 +49,8 @@ def quantize_global(x: torch.Tensor):
|
||||||
return output, absmax
|
return output, absmax
|
||||||
|
|
||||||
|
|
||||||
# global quantize and transpose
|
# global quantize and transpose
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||||
|
@ -51,9 +58,9 @@ def quantize_global(x: torch.Tensor):
|
||||||
# ...
|
# ...
|
||||||
],
|
],
|
||||||
key=['M', 'N']
|
key=['M', 'N']
|
||||||
)
|
)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
|
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
|
||||||
BLOCK_M : tl.constexpr,
|
BLOCK_M : tl.constexpr,
|
||||||
BLOCK_N : tl.constexpr,
|
BLOCK_N : tl.constexpr,
|
||||||
GROUP_M : tl.constexpr):
|
GROUP_M : tl.constexpr):
|
||||||
|
@ -84,7 +91,7 @@ def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, strid
|
||||||
|
|
||||||
tl.store(B, output, mask=mask)
|
tl.store(B, output, mask=mask)
|
||||||
|
|
||||||
def quantize_global_transpose(input):
|
def quantize_global_transpose(input):
|
||||||
absmax = input.abs().max().unsqueeze(0)
|
absmax = input.abs().max().unsqueeze(0)
|
||||||
absmax_inv = 1./ absmax
|
absmax_inv = 1./ absmax
|
||||||
M, N = input.shape
|
M, N = input.shape
|
||||||
|
|
|
@ -1,14 +1,21 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
|
||||||
|
|
||||||
# rowwise quantize
|
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||||
|
|
||||||
# TODO: autotune this better.
|
if not is_triton_available():
|
||||||
@triton.autotune(
|
def quantize_rowwise(x: torch.Tensor): return None
|
||||||
|
else:
|
||||||
|
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||||
|
|
||||||
|
# rowwise quantize
|
||||||
|
|
||||||
|
# TODO: autotune this better.
|
||||||
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({}, num_stages=1, num_warps=8),
|
triton.Config({}, num_stages=1, num_warps=8),
|
||||||
triton.Config({}, num_stages=2, num_warps=8),
|
triton.Config({}, num_stages=2, num_warps=8),
|
||||||
|
@ -24,16 +31,16 @@ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_tim
|
||||||
triton.Config({}, num_warps=8),
|
triton.Config({}, num_warps=8),
|
||||||
],
|
],
|
||||||
key=['n_elements']
|
key=['n_elements']
|
||||||
)
|
)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _quantize_rowwise(
|
def _quantize_rowwise(
|
||||||
x_ptr,
|
x_ptr,
|
||||||
output_ptr,
|
output_ptr,
|
||||||
output_maxs,
|
output_maxs,
|
||||||
n_elements,
|
n_elements,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
P2: tl.constexpr,
|
P2: tl.constexpr,
|
||||||
):
|
):
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
block_start = pid * BLOCK_SIZE
|
block_start = pid * BLOCK_SIZE
|
||||||
arange = tl.arange(0, P2)
|
arange = tl.arange(0, P2)
|
||||||
|
@ -47,7 +54,7 @@ def _quantize_rowwise(
|
||||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||||
tl.store(output_maxs + pid, max_val)
|
tl.store(output_maxs + pid, max_val)
|
||||||
|
|
||||||
def quantize_rowwise(x: torch.Tensor):
|
def quantize_rowwise(x: torch.Tensor):
|
||||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
|
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
|
||||||
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
|
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user