Refactored triton into its own folder. Refactored fp8 matmuls.
This commit is contained in:
parent
d677a71607
commit
ec1ea63711
|
@ -3,18 +3,13 @@
|
|||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from . import cuda_setup, utils
|
||||
from . import cuda_setup, utils, research
|
||||
from .autograd._functions import (
|
||||
MatmulLtState,
|
||||
bmm_cublas,
|
||||
matmul,
|
||||
matmul_cublas,
|
||||
mm_cublas,
|
||||
matmul_fp8,
|
||||
matmul_mixed,
|
||||
matmul_fp8_global,
|
||||
matmul_fp4,
|
||||
matmul_fp8_mixed,
|
||||
)
|
||||
from .cextension import COMPILED_WITH_CUDA
|
||||
from .nn import modules
|
||||
|
|
|
@ -390,518 +390,6 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
class MatMulFP8(torch.autograd.Function):
    """Matmul whose operands are round-tripped through a quantization codebook.

    Forward: ``A`` is quantized blockwise and ``B`` tensor-wise with
    ``fw_code``; both are dequantized again and multiplied with
    ``torch.matmul``.

    Backward: ``grad_output`` is round-tripped through ``bw_code`` (blockwise
    for ``grad_A``, tensor-wise for ``grad_B``) and the saved ``A`` is
    re-quantized with ``fw_code`` before computing ``grad_B``.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        """Compute ``dequant(quant(A)) @ dequant(quant(B))``.

        Args:
            A: left operand.
            B: right operand.
            out: accepted for signature compatibility; not used here.
            fw_code: codebook used to quantize ``A`` and ``B``.
            bw_code: codebook stored for the backward pass.
            bsz: block size for the blockwise quantization of ``A``.
            bsz2: block size stored for the blockwise quantization of
                ``grad_output`` in the backward pass.

        Returns:
            The matmul result, or an uninitialized empty tensor of the
            matching shape when ``A`` has no elements.
        """
        # Default to PyTorch behavior if inputs are empty.
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B

            B_shape = B.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize + dequantize both operands.
        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        # 2. Matmul on the dequantized operands.
        output = torch.matmul(fp8A, fp8B)

        # 3. Save state for backward.
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: we keep the original A (re-quantized in backward) and the
            # already dequantized B.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per forward input (only ``A`` and ``B`` get one)."""
        if ctx.is_empty:
            # forward takes 7 inputs (A, B, out, fw_code, bw_code, bsz, bsz2),
            # so autograd expects exactly 7 entries here.
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None

        # TODO (carried over from the original): a vector-wise quantized
        # transpose of grad_output was sketched here but was marked
        # "not supported by PyTorch"; a work-around is still pending.

        if req_gradA:
            # Blockwise round-trip of the incoming gradient.
            cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
            fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

        if req_gradB:
            # Tensor-wise round-trip of the incoming gradient.
            cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
            fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)

            # transpose(2, 1) requires A to have at least three dimensions.
            At = A.transpose(2, 1).contiguous()
            cA, state = F.quantize(At.float(), code=ctx.fw_code)
            fp8At = F.dequantize(cA, state).to(A.dtype)
            grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)

        return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
class MatMulFP8Mixed(torch.autograd.Function):
    """Codebook-quantized matmul with an unquantized weight gradient.

    Forward: ``A`` is quantized blockwise and ``B`` tensor-wise with
    ``fw_code``; both are dequantized again and multiplied with
    ``torch.matmul``.

    Backward: ``grad_A`` uses ``grad_output`` round-tripped blockwise through
    ``bw_code``; ``grad_B`` is computed from the unquantized ``A`` and the
    unquantized ``grad_output``.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        """Compute ``dequant(quant(A)) @ dequant(quant(B))``.

        Args:
            A: left operand.
            B: right operand.
            out: accepted for signature compatibility; not used here.
            fw_code: codebook used to quantize ``A`` and ``B``.
            bw_code: codebook stored for the backward pass.
            bsz: block size for the blockwise quantization of ``A``.
            bsz2: block size stored for the blockwise quantization of
                ``grad_output`` in the backward pass.

        Returns:
            The matmul result, or an uninitialized empty tensor of the
            matching shape when ``A`` has no elements.
        """
        # Default to PyTorch behavior if inputs are empty.
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B

            B_shape = B.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize + dequantize both operands.
        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        # 2. Matmul on the dequantized operands.
        output = torch.matmul(fp8A, fp8B)

        # 3. Save state for backward.
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: we keep the original A and the already dequantized B.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per forward input (only ``A`` and ``B`` get one)."""
        if ctx.is_empty:
            # forward takes 7 inputs (A, B, out, fw_code, bw_code, bsz, bsz2),
            # so autograd expects exactly 7 entries here.
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None

        # TODO (carried over from the original): a vector-wise quantized
        # transpose of grad_output was sketched here but was marked
        # "not supported by PyTorch"; a work-around is still pending.

        if req_gradA:
            # TODO: Fix blocksize to be output_dim
            cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
            fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

        if req_gradB:
            # The weight gradient deliberately uses the unquantized A and
            # grad_output; transpose(2, 1) requires A to have >= 3 dims.
            At = A.transpose(2, 1).contiguous()
            grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)

        return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
class MatMulFP4(torch.autograd.Function):
    """Matmul whose operands are round-tripped through a quantization codebook.

    Forward: ``A`` is quantized blockwise and ``B`` tensor-wise with
    ``fw_code``; both are dequantized again and multiplied with
    ``torch.matmul``.

    Backward: ``grad_output`` is round-tripped through ``bw_code`` (blockwise
    for ``grad_A``, tensor-wise for ``grad_B``) and the transposed ``A`` is
    round-tripped through ``bw_code`` as well.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        """Compute ``dequant(quant(A)) @ dequant(quant(B))``.

        Args:
            A: left operand.
            B: right operand.
            out: accepted for signature compatibility; not used here.
            fw_code: codebook used to quantize ``A`` and ``B``.
            bw_code: codebook stored for the backward pass.
            bsz: block size for the blockwise quantization of ``A``.
            bsz2: block size stored for the blockwise quantization of
                ``grad_output`` in the backward pass.

        Returns:
            The matmul result, or an uninitialized empty tensor of the
            matching shape when ``A`` has no elements.
        """
        # Default to PyTorch behavior if inputs are empty.
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B

            B_shape = B.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize + dequantize both operands.
        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        # 2. Matmul on the dequantized operands.
        output = torch.matmul(fp8A, fp8B)

        # 3. Save state for backward.
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: we keep the original A (re-quantized in backward) and the
            # already dequantized B.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per forward input (only ``A`` and ``B`` get one)."""
        if ctx.is_empty:
            # forward takes 7 inputs (A, B, out, fw_code, bw_code, bsz, bsz2),
            # so autograd expects exactly 7 entries here.
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None

        # TODO (carried over from the original): a vector-wise quantized
        # transpose of grad_output was sketched here but was marked
        # "not supported by PyTorch"; a work-around is still pending.

        if req_gradA:
            # TODO: Fix blocksize to be output_dim
            cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
            fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

        if req_gradB:
            # Tensor-wise round-trip of the incoming gradient.
            cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
            fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)

            # transpose(2, 1) requires A to have at least three dimensions.
            At = A.transpose(2, 1).contiguous()
            # NOTE(review): A is re-quantized with bw_code here, while the
            # forward pass quantized it with fw_code -- confirm this is intended.
            cA, state = F.quantize(At.float(), code=ctx.bw_code)
            fp8At = F.dequantize(cA, state).to(A.dtype)
            grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)

        return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
|
||||
class MatMulFP8Global(torch.autograd.Function):
    """Matmul with tensor-wise codebook quantization of every operand.

    Forward: ``A`` and ``B`` are each quantized with ``F.quantize`` using
    ``fw_code``, dequantized again, and multiplied with ``torch.matmul``.

    Backward: ``grad_output`` is round-tripped through ``bw_code`` and the
    transposed ``A`` through ``fw_code``, both with ``F.quantize``.
    ``bsz`` and ``bsz2`` are stored on ``ctx`` but not used by this class.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        """Compute ``dequant(quant(A)) @ dequant(quant(B))``.

        Args:
            A: left operand.
            B: right operand.
            out: accepted for signature compatibility; not used here.
            fw_code: codebook used to quantize ``A`` and ``B``.
            bw_code: codebook stored for the backward pass.
            bsz: stored on ``ctx``; not used for quantization here.
            bsz2: stored on ``ctx``; not used for quantization here.

        Returns:
            The matmul result, or an uninitialized empty tensor of the
            matching shape when ``A`` has no elements.
        """
        # Default to PyTorch behavior if inputs are empty.
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B

            B_shape = B.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize + dequantize both operands.
        cA, state = F.quantize(A.float(), code=fw_code)
        fp8A = F.dequantize(cA, state).to(A.dtype)

        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        # 2. Matmul on the dequantized operands.
        output = torch.matmul(fp8A, fp8B)

        # 3. Save state for backward.
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: we keep the original A (re-quantized in backward) and the
            # already dequantized B.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per forward input (only ``A`` and ``B`` get one)."""
        if ctx.is_empty:
            # forward takes 7 inputs (A, B, out, fw_code, bw_code, bsz, bsz2),
            # so autograd expects exactly 7 entries here.
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None

        # The same round-tripped gradient feeds both grad_A and grad_B.
        cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code)
        fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype)

        # TODO (carried over from the original): a vector-wise quantized
        # transpose of grad_output was sketched here but was marked
        # "not supported by PyTorch"; a work-around is still pending.

        if req_gradA:
            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

        if req_gradB:
            # transpose(2, 1) requires A to have at least three dimensions.
            At = A.transpose(2, 1).contiguous()
            cA, state = F.quantize(At.float(), code=ctx.fw_code)
            fp8At = F.dequantize(cA, state).to(A.dtype)
            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)

        return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
class MatMul8bitMixed(torch.autograd.Function):
    """Int8 matmul with outlier decomposition and a 16-bit weight gradient.

    Forward quantizes ``A`` (and ``B`` when ``state.has_fp16_weights``) with
    ``F.double_quant``, multiplies with ``F.igemmlt`` and dequantizes with
    ``F.mm_dequant``; columns flagged as outliers are multiplied separately
    with ``torch.matmul`` and added to the result.

    Backward computes ``grad_B`` with a plain ``torch.matmul`` on the saved
    ``A`` and computes ``grad_A`` either through the int8 path
    (``state.CBt``) or by dequantizing ``state.CB``.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=None):
        """Run the mixed int8 forward pass.

        Args:
            A: input activations (2-D or 3-D; 3-D inputs are flattened to 2-D).
            B: weight matrix.
            out: accepted for signature compatibility; not used here.
            bias: optional bias, fused into the dequantization when it is
                ``None`` or float16 and added separately otherwise.
            state: quantization state; a fresh ``MatmulLtState`` is created
                when omitted, so no instance is shared between calls.

        Returns:
            The result reshaped to the input's leading dimensions, or an
            uninitialized empty tensor when ``A`` has no elements.
        """
        # Default to PyTorch behavior if inputs are empty.
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

        if state is None:
            state = MatmulLtState()

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()

        # A is cast to fp16 for quantization below; warn if that changes dtype.
        if A.dtype != torch.float16:
            warnings.warn(f"MatMul8bitMixed: inputs will be cast from {A.dtype} to float16 during quantization")

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
            A.to(torch.float16), threshold=state.threshold
        )

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                # Zero the outlier columns in the int8 copies and keep their
                # 16-bit values for the separate matmul in step 4.
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = getattr(B, "grad", None) is not None
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                (
                    CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # Extract the outlier columns from the int8 weights and
            # dequantize them for the separate matmul in step 4.
            # (state.outlier_pool is populated above but not consulted here.)
            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0)
                .t()
                .contiguous()
                .to(A.dtype)
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        # we apply the fused bias here

        if bias is None or bias.dtype == torch.float16:
            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
            output = output.to(A.dtype)
        else:  # apply bias separately
            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
            output = output.to(A.dtype).add_(bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (CAt, subA, A)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(A, B, out, bias, state)``."""
        if ctx.is_empty:
            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA, A = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        grad_A = grad_B = grad_bias = None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # Flatten a 3-D gradient to 2-D to match the flattened A.
        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(
                -1, grad_output.shape[-1]
            ).contiguous()

        if req_gradB:
            # Weight gradient is computed in 16-bit from the saved A.
            grad_B = torch.matmul(grad_output.t(), A)

        if req_gradA:
            if state.CBt is not None:
                # The int8 gradient is only needed on this path, so it is
                # quantized here rather than unconditionally.
                Cgrad, _, SCgrad, _, _ = F.double_quant(grad_output.to(torch.float16))
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(
                        state.CBt, to_order=formatB, transpose=True
                    )
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

            elif state.CB is not None:
                # Dequantize the row-major int8 weights and use a plain matmul.
                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
                raise Exception('State must contain either CBt or CB matrix for backward')

        return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
|
||||
def matmul(
|
||||
A: tensor,
|
||||
B: tensor,
|
||||
|
@ -914,31 +402,3 @@ def matmul(
|
|||
if threshold > 0.0:
|
||||
state.threshold = threshold
|
||||
return MatMul8bitLt.apply(A, B, out, bias, state)
|
||||
|
||||
|
||||
def matmul_fp8(
    A: tensor,
    B: tensor,
    fw_code: tensor,
    bw_code: tensor,
    out: tensor = None,
    bsz: int = -1,
    bsz2: int = -1,
):
    """Run ``MatMulFP8`` on ``A`` and ``B``.

    The arguments are reordered to the positional layout expected by
    ``MatMulFP8.forward`` (``A, B, out, fw_code, bw_code, bsz, bsz2``).
    ``bsz`` and ``bsz2`` are forwarded unchanged, including their ``-1``
    defaults.
    """
    forward_args = (A, B, out, fw_code, bw_code, bsz, bsz2)
    return MatMulFP8.apply(*forward_args)
|
||||
|
||||
def matmul_fp8_global(
    A: tensor,
    B: tensor,
    fw_code: tensor,
    bw_code: tensor,
    out: tensor = None,
    bsz: int = -1,
    bsz2: int = -1,
):
    """Run ``MatMulFP8Global`` on ``A`` and ``B``.

    The arguments are reordered to the positional layout expected by
    ``MatMulFP8Global.forward`` (``A, B, out, fw_code, bw_code, bsz, bsz2``).
    ``bsz`` and ``bsz2`` are forwarded unchanged, including their ``-1``
    defaults.
    """
    forward_args = (A, B, out, fw_code, bw_code, bsz, bsz2)
    return MatMulFP8Global.apply(*forward_args)
|
||||
|
||||
def matmul_fp8_mixed(
    A: tensor,
    B: tensor,
    fw_code: tensor,
    bw_code: tensor,
    out: tensor = None,
    bsz: int = -1,
    bsz2: int = -1,
):
    """Run ``MatMulFP8Mixed`` on ``A`` and ``B``.

    The arguments are reordered to the positional layout expected by
    ``MatMulFP8Mixed.forward`` (``A, B, out, fw_code, bw_code, bsz, bsz2``).
    ``bsz`` and ``bsz2`` are forwarded unchanged, including their ``-1``
    defaults.
    """
    forward_args = (A, B, out, fw_code, bw_code, bsz, bsz2)
    return MatMulFP8Mixed.apply(*forward_args)
|
||||
|
||||
|
||||
def matmul_fp4(
    A: tensor,
    B: tensor,
    fw_code: tensor,
    bw_code: tensor,
    out: tensor = None,
    bsz: int = -1,
    bsz2: int = -1,
):
    """Run ``MatMulFP4`` on ``A`` and ``B``.

    The arguments are reordered to the positional layout expected by
    ``MatMulFP4.forward`` (``A, B, out, fw_code, bw_code, bsz, bsz2``).
    ``bsz`` and ``bsz2`` are forwarded unchanged, including their ``-1``
    defaults.
    """
    forward_args = (A, B, out, fw_code, bw_code, bsz, bsz2)
    return MatMulFP4.apply(*forward_args)
|
||||
|
||||
|
||||
def matmul_mixed(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
    """Run ``MatMul8bitMixed`` on ``A`` and ``B``.

    A new ``MatmulLtState`` is created when ``state`` is falsy. A positive
    ``threshold`` overwrites ``state.threshold`` before the call.
    """
    if not state:
        state = MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    forward_args = (A, B, out, bias, state)
    return MatMul8bitMixed.apply(*forward_args)
|
||||
|
|
|
@ -2,5 +2,5 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
|
||||
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
|
||||
|
|
|
@ -163,55 +163,6 @@ class OutlierAwareLinear(nn.Linear):
|
|||
return self.forward_with_outliers(x, self.outlier_dim)
|
||||
|
||||
|
||||
class Fake4bitLinear(OutlierAwareLinear):
    """Linear layer that simulates 4-bit quantization of weights and inputs.

    Weights are round-tripped through ``self.codebook`` and activations
    through a 4-bit dynamic map, both blockwise with block size 64. Columns
    listed as outliers are left unquantized.
    """

    def __init__(self, input_features, output_features, bias=True, codebook=bnb.functional.create_fp8_map(True, 3, 0, total_bits=4)):
        super().__init__(input_features, output_features, bias)
        self.codebook = codebook

    def quantize_weight(self, w, outlier_idx):
        """Return ``w`` round-tripped through the codebook, outlier columns kept.

        NOTE: when ``outlier_idx`` is non-empty, the outlier columns of the
        passed-in ``w`` are zeroed in place before quantization; the original
        values are restored only in the returned tensor.
        """
        has_outliers = outlier_idx.numel() > 0
        if has_outliers:
            subw = w[:, outlier_idx].clone()
            w[:, outlier_idx] = 0
        wdtype = w.dtype
        code = self.codebook.to(w.device)
        cw, state = bnb.functional.quantize_blockwise(w, code=code, blocksize=64)
        w = bnb.functional.dequantize_blockwise(cw, state, blocksize=64)
        w = w.to(wdtype)
        if has_outliers:
            w[:, outlier_idx] = subw
        self.is_quantized = True
        return w

    def forward_with_outliers(self, x, outlier_idx):
        """Quantize the non-outlier features of ``x`` and apply the linear map.

        A feature (last-dimension index) is treated as an outlier when it is
        in ``outlier_idx`` or when any of its activations has magnitude > 4.

        NOTE(review): the quantized values are written back into ``x`` in
        place, as in the original code -- confirm callers expect this.
        """
        # Count, per feature, how many activations exceed the magnitude
        # threshold. The threshold must be applied to |x|, not abs() to the
        # boolean mask, otherwise large negative values are never detected.
        dims = (torch.abs(x) > 4).sum(dim=list(range(len(x.shape) - 1)))
        outlier_idx2 = torch.where(dims > 0)[0]
        outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()

        # Indices of all features that are NOT outliers.
        n = x.shape[-1]
        idx = torch.arange(n, device=x.device)
        idx[outlier_idx] = -1
        inverse_idx = torch.where(idx >= 0)[0]

        inverse_x = x[..., inverse_idx]
        code = bnb.functional.create_dynamic_map(True, total_bits=4.0).to(x.device)
        c, state = bnb.functional.quantize_blockwise(inverse_x, code=code, blocksize=64)
        inverse_x = bnb.functional.dequantize_blockwise(c, state, blocksize=64)
        x[..., inverse_idx] = inverse_x.to(x.dtype)

        return torch.nn.functional.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
|
||||
class Int8Params(torch.nn.Parameter):
|
||||
def __new__(
|
||||
cls,
|
||||
|
@ -346,67 +297,6 @@ class Linear8bitLt(nn.Linear):
|
|||
return out
|
||||
|
||||
|
||||
# Not in use for now...
|
||||
class Linear8bitLt2(nn.Linear):
    """``nn.Linear`` variant that runs its matmul through ``bnb.matmul``.

    The weight is wrapped in ``Int8Params`` and a ``bnb.MatmulLtState`` is
    carried on the module. The bias is added after the matmul rather than
    being fused into it.
    """

    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__(
            input_features, output_features, bias
        )
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
        )

    def init_8bit_state(self):
        """Move the quantized weight (``CB``/``SCB``) from the parameter to the state."""
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x):
        """Apply the layer to ``x``; the bias is skipped when the layer has none."""
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()

        out = bnb.matmul(x, self.weight, bias=None, state=self.state)
        # nn.Linear sets self.bias to None when built with bias=False;
        # adding None to a tensor would raise.
        if self.bias is not None:
            out = out + self.bias

        if not self.state.has_fp16_weights:
            if not self.state.memory_efficient_backward and self.state.CB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
            elif self.state.memory_efficient_backward and self.state.CxB is not None:
                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
                # Thus, we delete CxB from the state.
                del self.state.CxB

        return out
|
||||
|
||||
class Linear8bitLtMixed(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -508,7 +398,7 @@ class LinearFP8(nn.Linear):
|
|||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
out = bnb.research.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
|
@ -534,7 +424,7 @@ class LinearFP8Mixed(nn.Linear):
|
|||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
|
@ -638,4 +528,4 @@ class LinearFP4(nn.Linear):
|
|||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
|
||||
return out
|
||||
|
|
|
@ -3,12 +3,12 @@ import torch.nn as nn
|
|||
import time
|
||||
from functools import partial
|
||||
|
||||
from .triton_utils.v0.dequantize_rowwise import dequantize_rowwise
|
||||
from .triton_utils.v0.quantize_rowwise import quantize_rowwise
|
||||
from .triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
||||
from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
|
||||
from .triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
|
||||
from .triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
|
||||
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
|
||||
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
|
||||
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
||||
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
|
||||
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
|
||||
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
|
||||
|
||||
|
||||
class _switchback_global(torch.autograd.Function):
|
||||
|
@ -55,7 +55,7 @@ class _switchback_global(torch.autograd.Function):
|
|||
grad_bias = G.sum(dim=0)
|
||||
|
||||
return grad_X, grad_W, grad_bias
|
||||
|
||||
|
||||
class _switchback_vectorrize(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
|
@ -74,7 +74,7 @@ class _switchback_vectorrize(torch.autograd.Function):
|
|||
return int8_matmul_rowwise_dequantize(
|
||||
X_int8, W_int8.t(), state_X, state_W, bias
|
||||
).view(*X_3D.size()[:-1], -1)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, G_3D):
|
||||
X, W = ctx.save_for_backward
|
||||
|
@ -98,7 +98,7 @@ class _switchback_vectorrize(torch.autograd.Function):
|
|||
grad_bias = G.sum(dim=0)
|
||||
|
||||
return grad_X, grad_W, grad_bias
|
||||
|
||||
|
||||
class _switchback_global_mem_efficient(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
|
@ -149,11 +149,11 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
|
|||
|
||||
class SwitchBackLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
vectorize: bool = False,
|
||||
mem_efficient : bool = False,
|
||||
|
@ -186,7 +186,7 @@ class SwitchBackLinear(nn.Linear):
|
|||
W_int8, state_W = quantize_rowwise(self.weight)
|
||||
else:
|
||||
W_int8, state_W = quantize_global(self.weight)
|
||||
|
||||
|
||||
self.register_buffer("W_int8", W_int8)
|
||||
self.register_buffer("state_W", state_W)
|
||||
|
||||
|
@ -199,7 +199,7 @@ class SwitchBackLinear(nn.Linear):
|
|||
# If it hasn't been "prepared for eval", run the standard forward pass.
|
||||
if not hasattr(self, "W_int8"):
|
||||
return self._fn.apply(x, self.weight, self.bias)
|
||||
|
||||
|
||||
# Otherwise, use pre-computed weights.
|
||||
X = x.view(-1, x.size(-1))
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
|
@ -250,4 +250,3 @@ class StandardLinear(nn.Linear):
|
|||
|
||||
def forward(self, x):
|
||||
return StandardLinearFunction.apply(x, self.weight, self.bias)
|
||||
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# rowwise quantize
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
    configs=(
        [triton.Config({}, num_stages=stages, num_warps=8) for stages in (1, 2, 4, 8)]
        + [triton.Config({}, num_stages=stages) for stages in (1, 2, 4, 8)]
        + [triton.Config({}, num_warps=warps) for warps in (1, 2, 4, 8)]
    ),
    key=['n_elements'],
)
@triton.jit
def _dequantize_rowwise(
    x_ptr,
    state_x,
    output_ptr,
    inv_127,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    """Dequantize one row per program: ``out = state_x[row] * x * inv_127``.

    Args:
        x_ptr: base pointer of the quantized input, read at ``row * BLOCK_SIZE + lane``.
        state_x: pointer to one scale value per program (indexed by program id).
        output_ptr: base pointer of the output, written at the same offsets as the input.
        inv_127: scalar multiplier applied after the per-row scale.
        n_elements: only used as the autotune cache key.
        BLOCK_SIZE: number of valid elements per row.
        P2: width of the lane range; lanes at or beyond BLOCK_SIZE are masked off.
    """
    row = tl.program_id(axis=0)
    lane = tl.arange(0, P2)
    # Lanes past the real row length are padding and must not touch memory.
    in_row = lane < BLOCK_SIZE
    elem_offsets = row * BLOCK_SIZE + lane

    quantized = tl.load(x_ptr + elem_offsets, mask=in_row)
    row_scale = tl.load(state_x + row)
    restored = row_scale * quantized * inv_127
    tl.store(output_ptr + elem_offsets, restored, mask=in_row)
|
||||
|
||||
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
    """Dequantize ``x`` row by row with the per-row scales in ``state_x``.

    Launches one ``_dequantize_rowwise`` program per row (``x.shape[0]`` programs),
    each covering ``x.shape[1]`` elements.

    Args:
        x: CUDA tensor of quantized values; indexed as 2-D (rows, columns).
        state_x: per-row scale values, read by the kernel at index ``row``.

    Returns:
        A float16 tensor with the same shape and device as ``x``.
    """
    row_len = x.shape[1]
    result = torch.empty(x.shape, device=x.device, dtype=torch.float16)

    # tl.arange needs a power-of-two extent, so round the row length up.
    padded_len = 2 ** math.ceil(math.log2(row_len))

    assert x.is_cuda and result.is_cuda
    total = result.numel()
    launch_grid = (x.shape[0],)
    _dequantize_rowwise[launch_grid](
        x, state_x, result, 1. / 127, total, BLOCK_SIZE=row_len, P2=padded_len
    )
    return result
|
|
@ -1,158 +0,0 @@
|
|||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and global quantized weight
|
||||
# Its purpose is a fused matmul followed by dequantization.
|
||||
# It does support bias.
|
||||
|
||||
def init_to_zero(name):
    """Build a hook that zeroes the argument called ``name`` in place.

    The returned callable takes a mapping of argument names to values and
    returns the result of calling ``.zero_()`` on ``nargs[name]``.
    """
    def _zero_named_arg(nargs):
        return nargs[name].zero_()

    return _zero_named_arg
|
||||
|
||||
def get_configs_io_bound():
    """Enumerate autotune configs with small BLOCK_M tiles (16 or 32).

    For every (num_stages, BLOCK_M, BLOCK_K, BLOCK_N) combination this emits one
    SPLIT_K=1 config followed by SPLIT_K in (2, 4, 8, 16) configs. The split-K
    configs carry a pre-hook that zeroes the argument named 'C' before launch.
    """
    io_configs = []
    for stages in (2, 3, 4, 5, 6):
        for bm in (16, 32):
            for bk in (32, 64):
                for bn in (32, 64, 128, 256):
                    warps = 4 if bn > 64 else 2
                    tile = {'BLOCK_M': bm, 'BLOCK_N': bn, 'BLOCK_K': bk}
                    io_configs.append(
                        triton.Config({**tile, 'SPLIT_K': 1}, num_stages=stages, num_warps=warps)
                    )
                    # split_k variants
                    io_configs.extend(
                        triton.Config({**tile, 'SPLIT_K': sk}, num_stages=stages,
                                      num_warps=warps, pre_hook=init_to_zero('C'))
                        for sk in (2, 4, 8, 16)
                    )
    return io_configs
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
                                  stride_am, stride_ak,
                                  stride_bk, stride_bn,
                                  stride_cm, stride_cn,
                                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                                  GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                                  ACC_TYPE: tl.constexpr
                                  ):
    """Fused matmul + dequantize: ``C = w * x[row] * (A @ B) * divfactor (+ bias)``.

    A is addressed as (M, K) and B as (K, N) through the given strides; C is
    (M, N). ``state_x_ptr`` supplies one scale per row of A and ``state_w_ptr``
    a single scalar scale for B. Products are accumulated in int32, scaled,
    cast to C's element type and then stored (SPLIT_K == 1) or atomically
    added into C (SPLIT_K > 1).

    ``bias`` is only dereferenced when ``has_bias`` is truthy. ``ACC_TYPE`` is
    accepted for signature compatibility; the accumulator is always int32.
    """
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrapped indices keep the A/B/scale loads in bounds for tail tiles.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    w_factor = tl.load(state_w_ptr)
    x_factor = tl.load(state_x_ptr + ram)[:, None]

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)

    # conditionally add bias
    if has_bias:
        # rn is not wrapped, so mask the load to stay inside the N bias entries.
        # Only the pid_z == 0 slice contributes the bias: with SPLIT_K > 1 every
        # slice is atomically added into C and the bias must be counted once.
        bias_mask = (rn < N) & (pid_z == 0)
        bias_vals = tl.load(bias + rn, mask=bias_mask, other=0.).to(C.dtype.element_ty)
        acc = acc + bias_vals[None, :]

    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
    """Launch the fused matmul/dequantize kernel and return a float16 (M, N) result.

    Args:
        a: 2-D left operand of shape (M, K).
        b: 2-D right operand of shape (K, N).
        state_x: scale values for ``a``, forwarded to the kernel.
        state_w: scale value for ``b``, forwarded to the kernel.
        bias: optional bias forwarded to the kernel; ``None`` disables the bias add.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    out_device = a.device
    scale = 1. / (127. * 127.)
    bias_flag = int(bias is not None)

    # A tensor with no unit stride in either dimension is copied to a contiguous layout.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    c = torch.empty((M, N), device=out_device, dtype=torch.float16)
    acc_type = tl.float32

    def launch_grid(META):
        # One program per output tile, times SPLIT_K slices along the reduction.
        tiles = triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])
        return (tiles, META['SPLIT_K'])

    _int8_matmul_mixed_dequantize[launch_grid](
        a, b, c, bias, state_x, state_w, M, N, K, scale, bias_flag,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        GROUP_M=8, ACC_TYPE=acc_type,
    )
    return c
|
|
@ -1,159 +0,0 @@
|
|||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and columnwise quantized weight
|
||||
# Its purpose is a fused matmul followed by dequantization.
|
||||
# It does support bias.
|
||||
|
||||
def init_to_zero(name):
    """Return a callable that zeroes ``nargs[name]`` in place.

    The callable receives a mapping of argument names to values and returns
    whatever ``.zero_()`` returns for the selected entry.
    """
    def _reset(nargs):
        target = nargs[name]
        return target.zero_()

    return _reset
|
||||
|
||||
|
||||
def get_configs_io_bound():
    """Return the autotune configs that use small BLOCK_M tiles (16 or 32).

    Each (num_stages, BLOCK_M, BLOCK_K, BLOCK_N) combination yields a SPLIT_K=1
    config and then SPLIT_K in (2, 4, 8, 16) configs; the split-K ones zero the
    argument named 'C' through a pre-hook.
    """
    collected = []
    for n_stages in (2, 3, 4, 5, 6):
        for tile_m in (16, 32):
            for tile_k in (32, 64):
                for tile_n in (32, 64, 128, 256):
                    n_warps = 4 if tile_n > 64 else 2
                    shape = {'BLOCK_M': tile_m, 'BLOCK_N': tile_n, 'BLOCK_K': tile_k}
                    collected.append(
                        triton.Config({**shape, 'SPLIT_K': 1}, num_stages=n_stages, num_warps=n_warps)
                    )
                    # split_k variants
                    for factor in (2, 4, 8, 16):
                        collected.append(
                            triton.Config({**shape, 'SPLIT_K': factor}, num_stages=n_stages,
                                          num_warps=n_warps, pre_hook=init_to_zero('C'))
                        )
    return collected
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
                                    stride_am, stride_ak,
                                    stride_bk, stride_bn,
                                    stride_cm, stride_cn,
                                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                                    GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                                    ACC_TYPE: tl.constexpr
                                    ):
    """Fused matmul + dequantize: ``C = w[col] * x[row] * (A @ B) * divfactor (+ bias)``.

    A is addressed as (M, K) and B as (K, N) through the given strides; C is
    (M, N). ``state_x_ptr`` supplies one scale per row of A and ``state_w_ptr``
    one scale per column of B. Products are accumulated in int32, scaled, cast
    to C's element type and then stored (SPLIT_K == 1) or atomically added
    into C (SPLIT_K > 1).

    ``bias`` is only dereferenced when ``has_bias`` is truthy. ``ACC_TYPE`` is
    accepted for signature compatibility; the accumulator is always int32.
    """
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrapped indices keep the A/B/scale loads in bounds for tail tiles.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    w_factor = tl.load(state_w_ptr + rbn)[None, :]
    x_factor = tl.load(state_x_ptr + ram)[:, None]

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk

    acc = (w_factor * (x_factor * (acc * divfactor)))
    acc = acc.to(C.dtype.element_ty)

    if has_bias:
        # rn is not wrapped, so mask the load to stay inside the N bias entries.
        # Only the pid_z == 0 slice contributes the bias: with SPLIT_K > 1 every
        # slice is atomically added into C and the bias must be counted once.
        bias_mask = (rn < N) & (pid_z == 0)
        bias_vals = tl.load(bias + rn, mask=bias_mask, other=0.).to(C.dtype.element_ty)
        acc = acc + bias_vals[None, :]

    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
    """Launch the rowwise fused matmul/dequantize kernel; returns a float16 (M, N) tensor.

    Args:
        a: 2-D left operand of shape (M, K).
        b: 2-D right operand of shape (K, N).
        state_x: scale values for ``a``, forwarded to the kernel.
        state_w: scale values for ``b``, forwarded to the kernel.
        bias: optional bias forwarded to the kernel; ``None`` disables the bias add.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    scale = 1. / (127. * 127.)
    bias_flag = int(bias is not None)
    out_device = a.device

    # A tensor with no unit stride in either dimension is copied to a contiguous layout.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    c = torch.empty((M, N), device=out_device, dtype=torch.float16)
    acc_type = tl.float32

    def launch_grid(META):
        # One program per output tile, times SPLIT_K slices along the reduction.
        tiles = triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])
        return (tiles, META['SPLIT_K'])

    _int8_matmul_rowwise_dequantize[launch_grid](
        a, b, c, bias, state_x, state_w, M, N, K, scale, bias_flag,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        GROUP_M=8, ACC_TYPE=acc_type,
    )
    return c
|
|
@ -1,68 +0,0 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# This kernel does fused columnwise quantization and transpose.
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
    configs=(
        [triton.Config({}, num_stages=stages) for stages in (1, 2, 4, 8, 16)]
        + [triton.Config({}, num_stages=stages, num_warps=8) for stages in (1, 2, 4, 8, 16)]
        + [triton.Config({}, num_warps=warps) for warps in (1, 2, 4, 8)]
    ),
    key=['n_elements'],
)
@triton.jit
def _quantize_columnwise_and_transpose(
    x_ptr,
    output_ptr,
    output_maxs,
    n_elements,
    M : tl.constexpr, N : tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    """Quantize one column per program and write it out as a row.

    Program ``col`` reads M elements at ``col + i * N``, scales them by
    127 / max(|column|), rounds with ``llrint`` and stores them at
    ``col * M + i``. The column's absolute maximum goes to ``output_maxs[col]``.
    ``n_elements`` is only the autotune key; ``BLOCK_SIZE`` is unused in the
    body; lanes ``i >= M`` (up to P2) are masked.
    """
    col = tl.program_id(axis=0)
    lane = tl.arange(0, P2)
    valid = lane < M

    # Walk down the column: consecutive elements are N apart in the input.
    src_offsets = col + lane * N
    values = tl.load(x_ptr + src_offsets, mask=valid)

    col_absmax = tl.max(tl.where(valid, tl.abs(values), 0), axis=0)
    quantized = tl.libdevice.llrint(127. * (values / col_absmax))

    # The column becomes row `col` of the output, M elements long.
    dst_offsets = col * M + lane
    tl.store(output_ptr + dst_offsets, quantized, mask=valid)
    tl.store(output_maxs + col, col_absmax)
|
||||
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor):
    """Quantize each column of a 2-D CUDA tensor to int8 and return it transposed.

    Returns:
        A pair ``(output, output_maxs)``: ``output`` is an int8 tensor of shape
        (N, M) for an (M, N) input, and ``output_maxs`` is a float16 tensor with
        one entry per input column, filled by the kernel.
    """
    M, N = x.shape
    transposed = torch.empty(N, M, device=x.device, dtype=torch.int8)
    col_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)

    # tl.arange needs a power-of-two extent, so round the column length up.
    padded_len = 2 ** math.ceil(math.log2(M))

    assert x.is_cuda and transposed.is_cuda
    total = transposed.numel()

    def launch_grid(meta):
        return (triton.cdiv(total, meta['BLOCK_SIZE']),)

    _quantize_columnwise_and_transpose[launch_grid](
        x, transposed, col_maxs, total, M, N, BLOCK_SIZE=M, P2=padded_len
    )
    return transposed, col_maxs
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# global quantize
|
||||
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048}, num_stages=1),
    ],
    key=['n_elements'],
)
@triton.jit
def _quantize_global(
    x_ptr,
    absmax_inv_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Quantize a flat buffer with one global scale.

    Each program handles BLOCK_SIZE consecutive elements and writes
    ``llrint(127 * (x * absmax_inv))``, where ``absmax_inv`` is the single
    value loaded from ``absmax_inv_ptr``. Offsets at or beyond ``n_elements``
    are masked.
    """
    chunk = tl.program_id(axis=0)
    elem_offsets = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = elem_offsets < n_elements

    values = tl.load(x_ptr + elem_offsets, mask=in_bounds)
    inv_scale = tl.load(absmax_inv_ptr)
    quantized = tl.libdevice.llrint(127. * (values * inv_scale))
    tl.store(output_ptr + elem_offsets, quantized, mask=in_bounds)
|
||||
|
||||
def quantize_global(x: torch.Tensor):
    """Quantize a CUDA tensor to int8 using a single global absolute maximum.

    Args:
        x: CUDA tensor to quantize.

    Returns:
        A pair ``(output, absmax)``: ``output`` is an int8 tensor with the shape
        of ``x`` on the same device, and ``absmax`` is the one-element tensor
        ``x.abs().max()`` that the values were scaled by.

    Raises:
        AssertionError: if ``x`` is not a CUDA tensor.
    """
    absmax = x.abs().max().unsqueeze(0)
    absmax_inv = 1. / absmax
    # Allocate on x's own device rather than the default 'cuda' device, so the
    # kernel never receives pointers from two different GPUs.
    output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _quantize_global[grid](x, absmax_inv, output, n_elements)
    return output, absmax
|
||||
|
||||
|
||||
# global quantize and transpose
|
||||
@triton.autotune(
    configs=[
        # NOTE: both entries are identical in the original config list.
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
    ],
    key=['M', 'N'],
)
@triton.jit
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
                               BLOCK_M : tl.constexpr,
                               BLOCK_N : tl.constexpr,
                               GROUP_M : tl.constexpr):
    """Quantize a (BLOCK_M, BLOCK_N) tile of A with one global scale and store it into B.

    Element (m, n) of A is read at ``m * stride_am + n * stride_an`` and the
    rounded value ``llrint(127 * (a * absmax_inv))`` is written to B at
    ``m * stride_bm + n * stride_bn``. Tiles are visited in groups of GROUP_M
    rows of tiles; out-of-range rows and columns are masked.
    """
    pid = tl.program_id(0)
    tiles_m = (M + BLOCK_M - 1) // BLOCK_M
    tiles_n = (N + BLOCK_N - 1) // BLOCK_N

    # Grouped ordering of the 1-D program id over the 2-D tile grid.
    programs_per_group = GROUP_M * tiles_n
    group = pid // programs_per_group
    rows_in_group = min(tiles_m - group * GROUP_M, GROUP_M)
    tile_m = group * GROUP_M + (pid % rows_in_group)
    tile_n = (pid % programs_per_group) // rows_in_group

    rm = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    src_ptrs = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
    in_bounds = (rm < M)[:, None] & (rn < N)[None, :]
    tile = tl.load(src_ptrs, mask=in_bounds)
    inv_scale = tl.load(absmax_inv_ptr)

    # rematerialize to save registers
    rm = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    dst_ptrs = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
    in_bounds = (rm < M)[:, None] & (rn < N)[None, :]

    quantized = tl.libdevice.llrint(127. * (tile * inv_scale))
    tl.store(dst_ptrs, quantized, mask=in_bounds)
|
||||
|
||||
def quantize_global_transpose(input):
    """Quantize a 2-D tensor to int8 with one global scale and return it transposed.

    Args:
        input: 2-D tensor of shape (M, N) with a unit stride in at least one
            dimension.

    Returns:
        A pair ``(out, absmax)``: ``out`` is an int8 tensor of shape (N, M) on
        the same device as ``input``, and ``absmax`` is the one-element tensor
        ``input.abs().max()`` that the values were scaled by.

    Raises:
        AssertionError: if neither stride of ``input`` (or of ``out``) is 1.
    """
    absmax = input.abs().max().unsqueeze(0)
    absmax_inv = 1. / absmax
    M, N = input.shape
    # Allocate on the input's own device rather than the default 'cuda' device,
    # so the kernel never receives pointers from two different GPUs.
    out = torch.empty(N, M, device=input.device, dtype=torch.int8)

    assert out.size(0) == N and out.size(1) == M
    assert input.stride(0) == 1 or input.stride(1) == 1
    assert out.stride(0) == 1 or out.stride(1) == 1

    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
    # out is (N, M): its strides are passed as (stride_bn, stride_bm).
    _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
    return out, absmax
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# rowwise quantize
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
    configs=(
        [triton.Config({}, num_stages=stages, num_warps=8) for stages in (1, 2, 4, 8)]
        + [triton.Config({}, num_stages=stages) for stages in (1, 2, 4, 8)]
        + [triton.Config({}, num_warps=warps) for warps in (1, 2, 4, 8)]
    ),
    key=['n_elements'],
)
@triton.jit
def _quantize_rowwise(
    x_ptr,
    output_ptr,
    output_maxs,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    """Quantize one row per program by that row's absolute maximum.

    Program ``row`` reads BLOCK_SIZE elements starting at ``row * BLOCK_SIZE``,
    writes ``llrint(127 * (x / max|x|))`` at the same offsets and stores the
    row maximum in ``output_maxs[row]``. ``n_elements`` is only the autotune
    key; lanes at or beyond BLOCK_SIZE (up to P2) are masked.
    """
    row = tl.program_id(axis=0)
    lane = tl.arange(0, P2)
    in_row = lane < BLOCK_SIZE
    elem_offsets = row * BLOCK_SIZE + lane

    values = tl.load(x_ptr + elem_offsets, mask=in_row)
    # Padding lanes contribute 0 so they cannot raise the row maximum.
    row_absmax = tl.max(tl.where(in_row, tl.abs(values), 0), axis=0)
    quantized = tl.libdevice.llrint(127. * (values / row_absmax))

    tl.store(output_ptr + elem_offsets, quantized, mask=in_row)
    tl.store(output_maxs + row, row_absmax)
|
||||
|
||||
def quantize_rowwise(x: torch.Tensor):
    """Quantize each row of a CUDA tensor to int8 by the row's absolute maximum.

    Launches one ``_quantize_rowwise`` program per row (``x.shape[0]`` programs),
    each covering ``x.shape[1]`` elements.

    Returns:
        A pair ``(output, output_maxs)``: ``output`` is an int8 tensor with the
        shape of ``x``, and ``output_maxs`` is a float16 tensor with one entry
        per row, filled by the kernel.
    """
    row_len = x.shape[1]
    quantized = torch.empty(x.shape, device=x.device, dtype=torch.int8)
    row_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)

    # tl.arange needs a power-of-two extent, so round the row length up.
    padded_len = 2 ** math.ceil(math.log2(row_len))

    assert x.is_cuda and quantized.is_cuda
    total = quantized.numel()
    launch_grid = (x.shape[0],)
    _quantize_rowwise[launch_grid](
        x, quantized, row_maxs, total, BLOCK_SIZE=row_len, P2=padded_len
    )
    return quantized, row_maxs
|
||||
|
|
@ -4,11 +4,11 @@ import time
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise
|
||||
from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
||||
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
|
||||
from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
|
||||
from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
|
||||
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
|
||||
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
||||
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
|
||||
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
|
||||
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
|
||||
|
||||
# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
|
||||
|
||||
|
|
|
@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
|
|||
dim2.append(0)
|
||||
|
||||
decomp = [0.0, 6.0]
|
||||
funcs = [(torch.matmul, bnb.matmul_mixed)]
|
||||
str_funcs = ["matmul"]
|
||||
funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
|
||||
str_funcs = ["matmullt", 'switchback_bnb']
|
||||
req_grad = [(False, False), (True, False), (True, True), (False, True)]
|
||||
req_grad = list(product([True, False], repeat=3))
|
||||
req_grad_str = []
|
||||
|
@ -441,7 +441,7 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
|
|||
|
||||
dim2.append(0)
|
||||
|
||||
funcs = [(torch.matmul, bnb.matmul_fp8)]
|
||||
funcs = [(torch.matmul, bnb.research.matmul_fp8)]
|
||||
str_funcs = ["matmul"]
|
||||
req_grad = list(product([True, False], repeat=3))
|
||||
req_grad_str = []
|
||||
|
|
|
@ -5,6 +5,7 @@ from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
|
|||
from bitsandbytes.nn import Linear8bitLt
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires a GPU with compute capability 8.0 or higher.")
|
||||
@pytest.mark.parametrize("vectorrize", [False, True])
|
||||
def test_switchback(vectorrize):
|
||||
for dim in [83, 17, 128]:
|
||||
|
@ -26,6 +27,7 @@ def test_switchback(vectorrize):
|
|||
out_standard = standard(x1)
|
||||
(2**10 * out_standard.abs().mean()).backward()
|
||||
|
||||
print(x2.dtype)
|
||||
out_sb = switchback(x2)
|
||||
(2**10 * out_sb.abs().mean()).backward()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user