diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 5d80df9..dcbc423 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,18 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, utils +from . import cuda_setup, utils, research from .autograd._functions import ( MatmulLtState, bmm_cublas, matmul, matmul_cublas, mm_cublas, - matmul_fp8, - matmul_mixed, - matmul_fp8_global, - matmul_fp4, - matmul_fp8_mixed, ) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index b7da7b0..cfab4a4 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -390,518 +390,6 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None -class MatMulFP8(torch.autograd.Function): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - - @staticmethod - def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): - # default of pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - - B_shape = B.shape - if A.shape[-1] == B_shape[0]: - return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - - # 1. Dequantize - # 2. MatmulnN - cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz) - fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype) - - cB, state = F.quantize(B.float(), code=fw_code) - fp8B = F.dequantize(cB, state).to(B.dtype) - - output = torch.matmul(fp8A, fp8B) - - # output is half - - # 3. Save state - ctx.fw_code = fw_code - ctx.bw_code = bw_code - ctx.bsz = bsz - ctx.bsz2 = bsz2 - ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype - - if any(ctx.needs_input_grad[:2]): - # NOTE: we send back A, and re-quant. - ctx.tensors = (A, fp8B) - else: - ctx.tensors = (None, None) - - return output - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None - - req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad - A, B = ctx.tensors - - grad_A, grad_B = None, None - - cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2) - fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype) - - cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) - fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) - - # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') - # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose - # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) - - # not supported by PyTorch. TODO: create work-around - if req_gradA: - grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) - - if req_gradB: - At = A.transpose(2, 1).contiguous() - cA, state = F.quantize(At.float(), code=ctx.fw_code) - fp8At = F.dequantize(cA, state).to(A.dtype) - grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype) - - return grad_A, grad_B, None, None, None, None, None - -class MatMulFP8Mixed(torch.autograd.Function): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - - @staticmethod - def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): - # default of pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - - B_shape = B.shape - if A.shape[-1] == B_shape[0]: - return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - - # 1. Dequantize - # 2. MatmulnN - cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz) - fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype) - - cB, state = F.quantize(B.float(), code=fw_code) - fp8B = F.dequantize(cB, state).to(B.dtype) - - output = torch.matmul(fp8A, fp8B) - - # output is half - - # 3. Save state - ctx.fw_code = fw_code - ctx.bw_code = bw_code - ctx.bsz = bsz - ctx.bsz2 = bsz2 - ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype - - if any(ctx.needs_input_grad[:2]): - # NOTE: we send back A, and re-quant. - ctx.tensors = (A, fp8B) - else: - ctx.tensors = (None, None) - - return output - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None - - req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad - A, B = ctx.tensors - - grad_A, grad_B = None, None - - # TODO: Fix blocksize to be output_dim - cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2) - fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype) - - # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) - # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) - - # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') - # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose - # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) - - # not supported by PyTorch. TODO: create work-around - if req_gradA: - grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) - - if req_gradB: - At = A.transpose(2, 1).contiguous() - # cA, state = F.quantize(At.float(), code=ctx.fw_code) - # fp8At = F.dequantize(cA, state).to(A.dtype) - grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype) - - return grad_A, grad_B, None, None, None, None, None - -class MatMulFP4(torch.autograd.Function): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - - @staticmethod - def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): - # default of pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - - B_shape = B.shape - if A.shape[-1] == B_shape[0]: - return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - - # 1. Dequantize - # 2. MatmulnN - cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz) - fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype) - - cB, state = F.quantize(B.float(), code=fw_code) - fp8B = F.dequantize(cB, state).to(B.dtype) - - output = torch.matmul(fp8A, fp8B) - - # output is half - - # 3. Save state - ctx.fw_code = fw_code - ctx.bw_code = bw_code - ctx.bsz = bsz - ctx.bsz2 = bsz2 - ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype - - if any(ctx.needs_input_grad[:2]): - # NOTE: we send back A, and re-quant. - ctx.tensors = (A, fp8B) - else: - ctx.tensors = (None, None) - - return output - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None - - req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad - A, B = ctx.tensors - - grad_A, grad_B = None, None - - # TODO: Fix blocksize to be output_dim - cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2) - fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype) - - cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) - fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) - - # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') - # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose - # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) - - # not supported by PyTorch. TODO: create work-around - if req_gradA: - grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) - - if req_gradB: - At = A.transpose(2, 1).contiguous() - cA, state = F.quantize(At.float(), code=ctx.bw_code) - fp8At = F.dequantize(cA, state).to(A.dtype) - grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype) - - return grad_A, grad_B, None, None, None, None, None - - - -class MatMulFP8Global(torch.autograd.Function): - # forward is the same, but we added the fallback for pre-turing GPUs - # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") - - @staticmethod - def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): - # default of pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - - B_shape = B.shape - if A.shape[-1] == B_shape[0]: - return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - - # 1. Dequantize - # 2. MatmulnN - cA, state = F.quantize(A.float(), code=fw_code) - fp8A = F.dequantize(cA, state).to(A.dtype) - - cB, state = F.quantize(B.float(), code=fw_code) - fp8B = F.dequantize(cB, state).to(B.dtype) - - output = torch.matmul(fp8A, fp8B) - - # output is half - - # 3. Save state - ctx.fw_code = fw_code - ctx.bw_code = bw_code - ctx.bsz = bsz - ctx.bsz2 = bsz2 - ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype - - if any(ctx.needs_input_grad[:2]): - # NOTE: we send back A, and re-quant. - ctx.tensors = (A, fp8B) - else: - ctx.tensors = (None, None) - - return output - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None - - req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad - A, B = ctx.tensors - - grad_A, grad_B = None, None - - # TODO: Fix blocksize to be output_dim - cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code) - fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype) - - # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) - # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) - - # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') - # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose - # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) - - # not supported by PyTorch. TODO: create work-around - if req_gradA: - grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) - - if req_gradB: - At = A.transpose(2, 1).contiguous() - cA, state = F.quantize(At.float(), code=ctx.fw_code) - fp8At = F.dequantize(cA, state).to(A.dtype) - grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype) - - return grad_A, grad_B, None, None, None, None, None - - -class MatMul8bitMixed(torch.autograd.Function): - @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): - # default to pytorch behavior if inputs are empty - ctx.is_empty = False - if prod(A.shape) == 0: - ctx.is_empty = True - ctx.A = A - ctx.B = B - ctx.bias = bias - if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) - else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) - - # 1. Quantize A - # 2. Quantize B - # 3. Matmul - # 4. Mixed-precision decomposition matmul - # 5. Save state - formatB = state.formatB - input_shape = A.shape - if state.outlier_pool is None: - state.outlier_pool = GlobalOutlierPooler.get_instance() - - # Cast A to fp16 - if A.dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") - - # 1. Quantize A - if len(A.shape) == 3: - A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A.to(torch.float16), threshold=state.threshold - ) - - if state.threshold > 0.0 and coo_tensorA is not None: - if state.has_fp16_weights: - idx = torch.unique(coo_tensorA.colidx).long() - CA[:, idx] = 0 - CAt[:, idx] = 0 - subA = A[:, idx] - state.subB = B[:, idx].t().contiguous() - state.idx = idx - else: - if state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - else: - #print('A shape', A.shape) - if not state.has_fp16_weights and state.CxB is None: - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - subA = None - - # 2. Quantize B - if state.has_fp16_weights: - #print('B shape', B.shape) - has_grad = True if (getattr(B, "grad", None) is not None) else False - is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) - if is_transposed: - B = B.contiguous() - - if (state.is_training and not has_grad) or state.CxB is None: - state.reset_grads() - ( - CB, - state.CBt, - state.SCB, - state.SCBt, - coo_tensorB, - ) = F.double_quant(B.to(torch.float16)) - state.CxB, state.SB = F.transform(CB, to_order=formatB) - else: - has_grad = False - - if coo_tensorA is not None and not state.has_fp16_weights: - # extract outliers - - outlier_idx = torch.unique(coo_tensorA.colidx) - state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx - outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0) - .t() - .contiguous() - .to(A.dtype) - ) - CA[:, state.idx.long()] = 0 - CAt[:, state.idx.long()] = 0 - subA = A[:, state.idx.long()] - - shapeB = state.SB[0] - - if len(input_shape) == 3: - output_shape = (input_shape[0], input_shape[1], shapeB[0]) - else: - output_shape = (input_shape[0], shapeB[0]) - - # 3. Matmul - C32A, SA = F.transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - # we apply the fused bias here - - if bias is None or bias.dtype == torch.float16: - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A.dtype) - else: # apply bias separately - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A.dtype).add_(bias) - - # 4. Mixed-precision decomposition matmul - if coo_tensorA is not None and subA is not None: - output += torch.matmul(subA, state.subB) - - # 5. Save state - ctx.state = state - - ctx.formatB = formatB - ctx.grad_shape = input_shape - ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype - - if any(ctx.needs_input_grad[:2]): - ctx.tensors = (CAt, subA, A) - ctx.tensor_states = (SCAt, state.idx) - else: - ctx.tensors = [None, None, None] - ctx.tensor_states = (None, None) - ctx.save_for_backward(None, None) - - - clone_func = torch.clone if len(output_shape) == 3 else lambda x : x - return clone_func(output.view(output_shape)) - - @staticmethod - def backward(ctx, grad_output): - if ctx.is_empty: - bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad - CAt, subA, A = ctx.tensors - SCAt, idx = ctx.tensor_states - formatB = ctx.formatB - state = ctx.state - grad_A = grad_B = grad_bias = None - - if req_gradBias: - # compute grad_bias first before changing grad_output dtype - grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) - - # Cast grad_output to fp16 - if len(grad_output.shape) == 3: - grad_output = grad_output.reshape( - -1, grad_output.shape[-1] - ).contiguous() - - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) - - if req_gradB: - # print('back A shape', A.shape) - # print('grad output t shape', grad_output.t().shape) - grad_B = torch.matmul(grad_output.t(), A) - - if req_gradA: - if state.CBt is not None: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) - # print('back B shape', state.CxBt.shape) - # print('back grad shape', C32grad.shape) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) - - elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) - grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) - else: - raise Exception('State must contain either CBt or CB matrix for backward') - - return grad_A, grad_B, None, grad_bias, None - - def matmul( A: tensor, B: tensor, @@ -914,31 +402,3 @@ def matmul( if threshold > 0.0: state.threshold = threshold return MatMul8bitLt.apply(A, B, out, bias, state) - - -def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): - return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz, bsz2) - -def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): - return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) - -def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): - return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) - - -def matmul_fp4(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): - return MatMulFP4.apply(A, B, out, fw_code, bw_code, bsz, bsz2) - - -def matmul_mixed( - A: tensor, - B: tensor, - out: tensor = None, - state: MatmulLtState = None, - threshold=0.0, - bias=None -): - state = state or MatmulLtState() - if threshold > 0.0: - state.threshold = threshold - return MatMul8bitMixed.apply(A, B, out, bias, state) diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index c6141ad..51bccbc 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,5 +2,5 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed +from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 9cdcb4a..7150378 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -163,55 +163,6 @@ class OutlierAwareLinear(nn.Linear): return self.forward_with_outliers(x, self.outlier_dim) -class Fake4bitLinear(OutlierAwareLinear): - def __init__(self, input_features, output_features, bias=True, codebook=bnb.functional.create_fp8_map(True, 3, 0, total_bits=4)): - super().__init__(input_features, output_features, bias) - self.codebook = codebook - - def quantize_weight(self, w, outlier_idx): - if outlier_idx.numel() > 0: - subw = w[:, outlier_idx].clone() - w[:, outlier_idx] = 0 - wdtype = w.dtype - code = self.codebook.to(w.device) - cw, state = bnb.functional.quantize_blockwise(w, code=code, blocksize=64) - w = bnb.functional.dequantize_blockwise(cw, state, blocksize=64) - w = w.to(wdtype) - if outlier_idx.numel() > 0: - w[:, outlier_idx] = subw - self.is_quantized = True - return w - - def forward_with_outliers(self, x, outlier_idx): - dims = torch.abs(x> 4).sum(dim=list(range(len(x.shape)-1))) - outlier_idx2 = torch.where(dims > 0)[0] - outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique() - n = x.shape[-1] - idx = torch.arange(n, device=x.device) - idx[outlier_idx] = -1 - inverse_idx = torch.where(idx >= 0)[0] - if outlier_idx.numel() > 0: - subx = x[..., outlier_idx].clone() - #print(1, subx, 1) - #x[..., outlier_idx] = 0 - inverse_x = x[...,inverse_idx] - xdtype = x.dtype - #code = bnb.functional.create_fp8_map(True, 4-3, 2, 4).to(x.device) - #code = bnb.functional.create_quantile_map(x, 4).to(x.device) - code = bnb.functional.create_dynamic_map(True, total_bits=4.0).to(x.device) - c, state = bnb.functional.quantize_blockwise(inverse_x, code=code, blocksize=64) - inverse_x = bnb.functional.dequantize_blockwise(c, state, blocksize=64) - #c, state = bnb.functional.quantize_blockwise(x, code=code, blocksize=64) - #x = bnb.functional.dequantize_blockwise(c, state, blocksize=64) - x = x.to(xdtype) - x[..., inverse_idx] = inverse_x.to(x.dtype) - #if outlier_idx.numel() > 0: - #x[..., outlier_idx] = subx - - return torch.nn.functional.linear(x, self.weight, self.bias) - - - class Int8Params(torch.nn.Parameter): def __new__( cls, @@ -346,67 +297,6 @@ class Linear8bitLt(nn.Linear): return out -# Not in use for now... -class Linear8bitLt2(nn.Linear): - def __init__( - self, - input_features, - output_features, - bias=True, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, - ): - super().__init__( - input_features, output_features, bias - ) - self.state = bnb.MatmulLtState() - self.index = index - - self.state.threshold = threshold - self.state.has_fp16_weights = has_fp16_weights - self.state.memory_efficient_backward = memory_efficient_backward - if threshold > 0.0 and not has_fp16_weights: - self.state.use_pool = True - - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights - ) - - def init_8bit_state(self): - self.state.CB = self.weight.CB - self.state.SCB = self.weight.SCB - self.weight.CB = None - self.weight.SCB = None - - def forward(self, x): - self.state.is_training = self.training - - if self.weight.CB is not None: - self.init_8bit_state() - - # weights are cast automatically as Int8Params, but the bias has to be cast manually - # if self.bias is not None and self.bias.dtype != torch.float16: - # self.bias.data = self.bias.data.half() - - #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias - out = bnb.matmul(x, self.weight, bias=None, state=self.state) + self.bias - #out = torch.matmul(x.half(), W.half().t()) + self.bias - - if not self.state.has_fp16_weights: - if not self.state.memory_efficient_backward and self.state.CB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB - elif self.state.memory_efficient_backward and self.state.CxB is not None: - # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass. - # Thus, we delete CxB from the state. - del self.state.CxB - - return out - class Linear8bitLtMixed(nn.Linear): def __init__( self, @@ -508,7 +398,7 @@ class LinearFP8(nn.Linear): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + out = bnb.research.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) if self.bias is not None: out += self.bias @@ -534,7 +424,7 @@ class LinearFP8Mixed(nn.Linear): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) if self.bias is not None: out += self.bias @@ -638,4 +528,4 @@ class LinearFP4(nn.Linear): if self.bias is not None: out += self.bias - return out \ No newline at end of file + return out diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py index ffb1866..61e9053 100644 --- a/bitsandbytes/nn/triton_based_modules.py +++ b/bitsandbytes/nn/triton_based_modules.py @@ -3,12 +3,12 @@ import torch.nn as nn import time from functools import partial -from .triton_utils.v0.dequantize_rowwise import dequantize_rowwise -from .triton_utils.v0.quantize_rowwise import quantize_rowwise -from .triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose -from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize -from .triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose -from .triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze +from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise +from bitsandbytes.triton.quantize_rowwise import quantize_rowwise +from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose +from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze class _switchback_global(torch.autograd.Function): @@ -55,7 +55,7 @@ class _switchback_global(torch.autograd.Function): grad_bias = G.sum(dim=0) return grad_X, grad_W, grad_bias - + class _switchback_vectorrize(torch.autograd.Function): @staticmethod @@ -74,7 +74,7 @@ class _switchback_vectorrize(torch.autograd.Function): return int8_matmul_rowwise_dequantize( X_int8, W_int8.t(), state_X, state_W, bias ).view(*X_3D.size()[:-1], -1) - + @staticmethod def backward(ctx, G_3D): X, W = ctx.save_for_backward @@ -98,7 +98,7 @@ class _switchback_vectorrize(torch.autograd.Function): grad_bias = G.sum(dim=0) return grad_X, grad_W, grad_bias - + class _switchback_global_mem_efficient(torch.autograd.Function): @staticmethod @@ -149,11 +149,11 @@ class _switchback_global_mem_efficient(torch.autograd.Function): class SwitchBackLinear(nn.Linear): def __init__( - self, - in_features: int, - out_features: int, + self, + in_features: int, + out_features: int, bias: bool = True, - device=None, + device=None, dtype=None, vectorize: bool = False, mem_efficient : bool = False, @@ -186,7 +186,7 @@ class SwitchBackLinear(nn.Linear): W_int8, state_W = quantize_rowwise(self.weight) else: W_int8, state_W = quantize_global(self.weight) - + self.register_buffer("W_int8", W_int8) self.register_buffer("state_W", state_W) @@ -199,7 +199,7 @@ class SwitchBackLinear(nn.Linear): # If it hasn't been "prepared for eval", run the standard forward pass. if not hasattr(self, "W_int8"): return self._fn.apply(x, self.weight, self.bias) - + # Otherwise, use pre-computed weights. X = x.view(-1, x.size(-1)) X_int8, state_X = quantize_rowwise(X) @@ -250,4 +250,3 @@ class StandardLinear(nn.Linear): def forward(self, x): return StandardLinearFunction.apply(x, self.weight, self.bias) - diff --git a/bitsandbytes/nn/triton_utils/v0/__init__.py b/bitsandbytes/nn/triton_utils/v0/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bitsandbytes/nn/triton_utils/v0/dequantize_rowwise.py b/bitsandbytes/nn/triton_utils/v0/dequantize_rowwise.py deleted file mode 100644 index 7e31483..0000000 --- a/bitsandbytes/nn/triton_utils/v0/dequantize_rowwise.py +++ /dev/null @@ -1,58 +0,0 @@ -import math -import torch -import time -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -# rowwise quantize - -# TODO: autotune this better. -@triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] -) -@triton.jit -def _dequantize_rowwise( - x_ptr, - state_x, - output_ptr, - inv_127, - n_elements, - BLOCK_SIZE: tl.constexpr, - P2: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - arange = tl.arange(0, P2) - offsets = block_start + arange - row_mask = arange < BLOCK_SIZE - x = tl.load(x_ptr + offsets, mask=row_mask) - max_val = tl.load(state_x + pid) - output = max_val * x * inv_127 - tl.store(output_ptr + offsets, output, mask=row_mask) - - -def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): - output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) - - P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) - - assert x.is_cuda and output.is_cuda - n_elements = output.numel() - grid = lambda meta: (x.shape[0],) - _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) - return output diff --git a/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py b/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py deleted file mode 100644 index 69d4b0c..0000000 --- a/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py +++ /dev/null @@ -1,158 +0,0 @@ -import torch - -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - - -# This is a matmul kernel based on triton.ops.matmul -# It is modified to support rowwise quantized input and global quantized weight -# It's purpose is fused matmul then dequantize -# It does support bias. - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - -def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) - # split_k - for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) - return configs - - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': early_config_prune, - 'perf_model': estimate_matmul_time, - 'top_k': 10 - }, -) -@triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, -}) -@triton.jit -def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr - ): - # matrix multiplication - pid = tl.program_id(0) - pid_z = tl.program_id(1) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - w_factor = tl.load(state_w_ptr) - x_factor = tl.load(state_x_ptr + ram)[:, None] - - # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) - for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) - acc += tl.dot(a, b) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk - - acc = (w_factor * (x_factor * (acc * divfactor))) - acc = acc.to(C.dtype.element_ty) - - # conditionally add bias - if has_bias: - bias = tl.load(bias + rn).to(C.dtype.element_ty) - acc = acc + bias[None, :] - - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(C, acc, mask=mask) - else: - tl.atomic_add(C, acc, mask=mask) - - -def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): - device = a.device - divfactor = 1. / (127. * 127.) - has_bias = 0 if bias is None else 1 - # handle non-contiguous inputs if necessary - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # allocates output - c = torch.empty((M, N), device=device, dtype=torch.float16) - # accumulator types - ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # launch int8_matmul_mixed_dequantize kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) - return c diff --git a/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py deleted file mode 100644 index 4af054b..0000000 --- a/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch - -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -# This is a matmul kernel based on triton.ops.matmul -# It is modified to support rowwise quantized input and columnwise quantized weight -# It's purpose is fused matmul then dequantize -# It does support bias. - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) - # split_k - for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) - return configs - - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': early_config_prune, - 'perf_model': estimate_matmul_time, - 'top_k': 10 - }, -) -@triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, -}) -@triton.jit -def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr - ): - # matrix multiplication - pid = tl.program_id(0) - pid_z = tl.program_id(1) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - w_factor = tl.load(state_w_ptr + rbn)[None, :] - x_factor = tl.load(state_x_ptr + ram)[:, None] - - # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) - for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) - acc += tl.dot(a, b) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk - - acc = (w_factor * (x_factor * (acc * divfactor))) - acc = acc.to(C.dtype.element_ty) - - if has_bias: - bias = tl.load(bias + rn).to(C.dtype.element_ty) - acc = acc + bias[None, :] - - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(C, acc, mask=mask) - else: - tl.atomic_add(C, acc, mask=mask) - - -def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): - divfactor = 1. / (127. * 127.) - - has_bias = 0 if bias is None else 1 - - device = a.device - # handle non-contiguous inputs if necessary - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # allocates output - c = torch.empty((M, N), device=device, dtype=torch.float16) - # accumulator types - ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # launch int8_matmul_rowwise_dequantize kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) - return c diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_and_transpose.py b/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_and_transpose.py deleted file mode 100644 index 4e53475..0000000 --- a/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_and_transpose.py +++ /dev/null @@ -1,68 +0,0 @@ -import math -import torch -import time -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -# This kernel does fused columnwise quantization and transpose. - -# TODO: autotune this better. -@triton.autotune( - configs=[ - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_stages=16), - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=16, num_warps=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] -) -@triton.jit -def _quantize_columnwise_and_transpose( - x_ptr, - output_ptr, - output_maxs, - n_elements, - M : tl.constexpr, N : tl.constexpr, - BLOCK_SIZE: tl.constexpr, - P2: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid - p2_arange = tl.arange(0, P2) - p2_arange_mask = p2_arange < M - arange = p2_arange * N - offsets = block_start + arange - x = tl.load(x_ptr + offsets, mask=p2_arange_mask) - abs_x = tl.abs(x) - max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127. * (x / max_val)) - - new_start = pid * M - new_offsets = new_start + p2_arange - tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) - tl.store(output_maxs + pid, max_val) - -def quantize_columnwise_and_transpose(x: torch.Tensor): - M, N = x.shape - output = torch.empty(N, M, device=x.device, dtype=torch.int8) - output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) - - P2 = int(2 ** (math.ceil(math.log2(M)))) - - assert x.is_cuda and output.is_cuda - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) - return output, output_maxs - diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_global.py b/bitsandbytes/nn/triton_utils/v0/quantize_global.py deleted file mode 100644 index 229721c..0000000 --- a/bitsandbytes/nn/triton_utils/v0/quantize_global.py +++ /dev/null @@ -1,100 +0,0 @@ -import math -import torch -import time -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -# global quantize -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), - triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), - - ], - key=['n_elements'] -) -@triton.jit -def _quantize_global( - x_ptr, - absmax_inv_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - absmax_inv = tl.load(absmax_inv_ptr) - output = tl.libdevice.llrint(127. * (x * absmax_inv)) - tl.store(output_ptr + offsets, output, mask=mask) - -def quantize_global(x: torch.Tensor): - absmax = x.abs().max().unsqueeze(0) - absmax_inv = 1./ absmax - output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) - assert x.is_cuda and output.is_cuda - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _quantize_global[grid](x, absmax_inv, output, n_elements) - return output, absmax - - -# global quantize and transpose -@triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), - - # ... - ], - key=['M', 'N'] -) -@triton.jit -def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, - BLOCK_M : tl.constexpr, - BLOCK_N : tl.constexpr, - GROUP_M : tl.constexpr): - pid = tl.program_id(0) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // group_size - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) - mask = (rm < M)[:, None] & (rn < N)[None, :] - a = tl.load(A, mask=mask) - absmax_inv = tl.load(absmax_inv_ptr) - - # rematerialize to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - - output = tl.libdevice.llrint(127. * (a * absmax_inv)) - - tl.store(B, output, mask=mask) - -def quantize_global_transpose(input): - absmax = input.abs().max().unsqueeze(0) - absmax_inv = 1./ absmax - M, N = input.shape - out = torch.empty(N, M, device='cuda', dtype=torch.int8) - - assert out.size(0) == N and out.size(1) == M - assert input.stride(0) == 1 or input.stride(1) == 1 - assert out.stride(0) == 1 or out.stride(1) == 1 - - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) - _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) - return out, absmax - diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_rowwise.py b/bitsandbytes/nn/triton_utils/v0/quantize_rowwise.py deleted file mode 100644 index d956647..0000000 --- a/bitsandbytes/nn/triton_utils/v0/quantize_rowwise.py +++ /dev/null @@ -1,61 +0,0 @@ -import math -import torch -import time -import triton -import triton.language as tl -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - -# rowwise quantize - -# TODO: autotune this better. -@triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] -) -@triton.jit -def _quantize_rowwise( - x_ptr, - output_ptr, - output_maxs, - n_elements, - BLOCK_SIZE: tl.constexpr, - P2: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - arange = tl.arange(0, P2) - offsets = block_start + arange - row_mask = arange < BLOCK_SIZE - x = tl.load(x_ptr + offsets, mask=row_mask) - - abs_x = tl.abs(x) - max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127. * (x / max_val)) - tl.store(output_ptr + offsets, output, mask=row_mask) - tl.store(output_maxs + pid, max_val) - -def quantize_rowwise(x: torch.Tensor): - output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) - output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) - - P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) - - assert x.is_cuda and output.is_cuda - n_elements = output.numel() - grid = lambda meta: (x.shape[0],) - _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) - return output, output_maxs - diff --git a/speed_benchmark/speed_benchmark.py b/speed_benchmark/speed_benchmark.py index eccc455..9ad9911 100644 --- a/speed_benchmark/speed_benchmark.py +++ b/speed_benchmark/speed_benchmark.py @@ -4,11 +4,11 @@ import time import torch import torch.nn as nn -from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise -from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose -from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize -from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose -from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze +from bitsandbytes.triton.quantize_rowwise import quantize_rowwise +from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose +from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. diff --git a/tests/test_autograd.py b/tests/test_autograd.py index d05b4a6..ac2ae05 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist() dim2.append(0) decomp = [0.0, 6.0] -funcs = [(torch.matmul, bnb.matmul_mixed)] -str_funcs = ["matmul"] +funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)] +str_funcs = ["matmullt", 'switchback_bnb'] req_grad = [(False, False), (True, False), (True, True), (False, True)] req_grad = list(product([True, False], repeat=3)) req_grad_str = [] @@ -441,7 +441,7 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist() dim2.append(0) -funcs = [(torch.matmul, bnb.matmul_fp8)] +funcs = [(torch.matmul, bnb.research.matmul_fp8)] str_funcs = ["matmul"] req_grad = list(product([True, False], repeat=3)) req_grad_str = [] diff --git a/tests/test_triton.py b/tests/test_triton.py index 2ec34fb..7f56a49 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -5,6 +5,7 @@ from bitsandbytes.nn.triton_based_modules import SwitchBackLinear from bitsandbytes.nn import Linear8bitLt +@pytest.mark.skipif(not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires a GPU with compute capability 8.0 or higher.") @pytest.mark.parametrize("vectorrize", [False, True]) def test_switchback(vectorrize): for dim in [83, 17, 128]: @@ -26,6 +27,7 @@ def test_switchback(vectorrize): out_standard = standard(x1) (2**10 * out_standard.abs().mean()).backward() + print(x2.dtype) out_sb = switchback(x2) (2**10 * out_sb.abs().mean()).backward()