From de535889348c5406eb34d9f7e0c362cadb113be5 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 1 Feb 2023 20:09:31 -0800 Subject: [PATCH] Added Int8 matmul support for all GPUs. Full backward support. --- Makefile | 4 +- bitsandbytes/autograd/__init__.py | 1 + bitsandbytes/autograd/_functions.py | 159 +++++++++++++++++++++------- bitsandbytes/nn/modules.py | 36 ++----- tests/test_linear8bitlt.py | 61 +++++++++++ tests/test_modules.py | 2 +- 6 files changed, 195 insertions(+), 68 deletions(-) create mode 100644 tests/test_linear8bitlt.py diff --git a/Makefile b/Makefile index 10eb779..7bee7ef 100644 --- a/Makefile +++ b/Makefile @@ -60,8 +60,8 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py index e69de29..6b9a7e4 100644 --- a/bitsandbytes/autograd/__init__.py +++ b/bitsandbytes/autograd/__init__.py @@ -0,0 +1 @@ +from ._functions import undo_layout, get_inverse_transform_indices diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index a115437..376fb8a 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -2,6 +2,7 @@ import operator import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 +from typing import Tuple, Optional import torch @@ -14,6 +15,12 @@ def prod(iterable): tensor = torch.Tensor + +# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: +# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py + + + """ This class pools outlier dimensions across layers. This is particularly important for small models where outlier features @@ -48,6 +55,51 @@ class GlobalOutlierPooler: return torch.Tensor(list(self.outliers)).to(torch.int64) +def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]): + """ + Compute a permutation of indices that invert the specified (tiled) matrix transformation + + :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2] + :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere + :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size + :example: transform_tile function for the turing layout (bitsandbytes.functional as F) + :returns: indices + """ + d1, d2 = tile_size + assert 0 < d1 * d2 < 2**64 + tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2) + # encode each position in tile as a tuple of <= 8 unique bytes + permuted_tile_indices = torch.zeros_like(tile_indices) + for i in range(8): + # select i-th byte, apply transformation and trace where each index ended up + ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256 + sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous() + assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow" + permuted_tile_i = transform_tile(sample_tile_i) + ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128 + permuted_tile_indices += ith_permuted_indices * (256**i) + if d1 * d2 < 256**i: + break # if all indices fit in i bytes, stop early + return permuted_tile_indices + +def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: + """ + Undo a tiled permutation such as turing or ampere layout + + :param permuted_tensor: torch tensor in a permuted layout + :param tile_indices: reverse transformation indices, from get_inverse_transform_indices + :return: contiguous row-major tensor + """ + (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape + assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles" + tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() + outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda + outputs[tile_indices.flatten()] = tensor + outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows) + outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols) + return outputs.reshape(rows, cols).contiguous() + + class MatMul8bit(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, quant_type="vector", precision=None): @@ -171,6 +223,8 @@ matmul_cublas = MatMul8bit.apply @dataclass class MatmulLtState: + tile_indices: Optional[torch.Tensor] = None + force_no_igemmlt: bool = False CB = None CxB = None SB = None @@ -202,11 +256,22 @@ class MatmulLtState: self.SBt = None self.CBt = None + def get_tile_size(self): + assert self.formatB in ( + "col_turing", + "col_ampere", + ), f"please find this assert and manually enter tile size for {self.formatB}" + return (8, 32) if self.formatB == "col_turing" else (32, 32) + class MatMul8bitLt(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): - # default to pytorch behavior if inputs are empty + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): + using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt + # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: ctx.is_empty = True @@ -214,9 +279,9 @@ class MatMul8bitLt(torch.autograd.Function): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B @@ -235,9 +300,7 @@ class MatMul8bitLt(torch.autograd.Function): # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A.to(torch.float16), threshold=state.threshold - ) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -248,12 +311,12 @@ class MatMul8bitLt(torch.autograd.Function): state.subB = B[:, idx].t().contiguous() state.idx = idx else: - if state.CxB is None: + if state.CxB is None and using_igemmlt: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: - if not state.has_fp16_weights and state.CxB is None: + if not state.has_fp16_weights and state.CxB is None and using_igemmlt: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None @@ -273,7 +336,10 @@ class MatMul8bitLt(torch.autograd.Function): state.SCBt, coo_tensorB, ) = F.double_quant(B.to(torch.float16)) - state.CxB, state.SB = F.transform(CB, to_order=formatB) + if using_igemmlt: + state.CxB, state.SB = F.transform(CB, to_order=formatB) + else: + state.CB = CB else: has_grad = False @@ -288,18 +354,17 @@ class MatMul8bitLt(torch.autograd.Function): # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) # else: # state.idx = outlier_idx - outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0) - .t() - .contiguous() - .to(A.dtype) - ) + if state.CxB is not None: + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + else: + outliers = state.CB[:, state.idx.long()].clone() + + state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] - shapeB = state.SB[0] + shapeB = state.SB[0] if state.SB else B.shape if len(input_shape) == 3: output_shape = (input_shape[0], input_shape[1], shapeB[0]) @@ -307,16 +372,25 @@ class MatMul8bitLt(torch.autograd.Function): output_shape = (input_shape[0], shapeB[0]) # 3. Matmul - C32A, SA = F.transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - # we apply the fused bias here + if using_igemmlt: + C32A, SA = F.transform(CA, "col32") + out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + if bias is None or bias.dtype == torch.float16: + # we apply the fused bias here + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A.dtype) + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A.dtype).add_(bias) - if bias is None or bias.dtype == torch.float16: - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A.dtype) - else: # apply bias separately - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A.dtype).add_(bias) + else: + A_wo_outliers = A.clone() + if state.idx is not None: + A_wo_outliers[:, state.idx.long()] = 0 + output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype)) + output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0)) + if bias is not None: + output = output.add_(bias) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: @@ -337,14 +411,13 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - - clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + clone_func = torch.clone if len(output_shape) == 3 else lambda x: x return clone_func(output.view(output_shape)) @staticmethod def backward(ctx, grad_output): if ctx.is_empty: - bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA = ctx.tensors @@ -359,9 +432,7 @@ class MatMul8bitLt(torch.autograd.Function): # Cast grad_output to fp16 if len(grad_output.shape) == 3: - grad_output = grad_output.reshape( - -1, grad_output.shape[-1] - ).contiguous() + grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: @@ -376,17 +447,29 @@ class MatMul8bitLt(torch.autograd.Function): if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) + state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + elif state.CxB is not None: + + if state.tile_indices is None: + order, tile_size = state.formatB, state.get_tile_size() + transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) + with torch.no_grad(): + state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device) + + CB = ( + undo_layout(state.CxB, state.tile_indices) + .to(ctx.dtype_A) + .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + ) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception('State must contain either CBt or CB matrix for backward') + raise Exception("State must contain either CBt or CB or CxB matrix for backward") return grad_A, grad_B, None, grad_bias, None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index a623bf1..45df35e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -209,19 +209,10 @@ class Int8Params(torch.nn.Parameter): class Linear8bitLt(nn.Linear): - def __init__( - self, - input_features, - output_features, - bias=True, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, - ): - super().__init__( - input_features, output_features, bias - ) + def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, + memory_efficient_backward=False, threshold=0.0, index=None): + super().__init__(input_features, output_features, bias) + assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() self.index = index @@ -231,9 +222,7 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights - ) + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) def init_8bit_state(self): self.state.CB = self.weight.CB @@ -241,27 +230,20 @@ class Linear8bitLt(nn.Linear): self.weight.CB = None self.weight.SCB = None - def forward(self, x): + def forward(self, x: torch.Tensor): self.state.is_training = self.training - if self.weight.CB is not None: self.init_8bit_state() # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != torch.float16: - self.bias.data = self.bias.data.half() + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights: - if not self.state.memory_efficient_backward and self.state.CB is not None: + if self.state.CB is not None and self.state.CxB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass # we no longer need the row-major weight del self.state.CB self.weight.data = self.state.CxB - elif self.state.memory_efficient_backward and self.state.CxB is not None: - # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass. - # Thus, we delete CxB from the state. - del self.state.CxB - return out diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py new file mode 100644 index 0000000..4676b66 --- /dev/null +++ b/tests/test_linear8bitlt.py @@ -0,0 +1,61 @@ +import bitsandbytes as bnb +import pytest +import torch +from bitsandbytes import functional as F + +from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout +from bitsandbytes.nn.modules import Linear8bitLt + +# contributed by Alex Borzunov, see: +# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), + reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", +) +def test_layout_exact_match(): + x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda() + for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"): + transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) + tile_indices = get_inverse_transform_indices(transform, tile_size) + cxb = transform(x) + + torch.cuda.synchronize() + restored_x = undo_layout(cxb, tile_indices) + torch.cuda.synchronize() + assert restored_x.is_contiguous() + assert torch.all(torch.eq(restored_x, x)) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +def test_linear_no_igemmlt(): + linear = torch.nn.Linear(1024, 3072) + x = torch.randn(3, 1024, dtype=torch.half) + linear_custom = Linear8bitLt( + linear.in_features, + linear.out_features, + linear.bias is not None, + has_fp16_weights=False, + threshold=6.0, + ) + linear_custom.state.force_no_igemmlt = True + + linear_custom.weight = bnb.nn.Int8Params( + linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False + ).to(linear.weight.dtype) + linear_custom.bias = linear.bias + linear = linear_custom.cuda() + linear = linear.half().cuda() + + x_ref = x.clone().cuda().requires_grad_(True) + x_ours = x.clone().cuda().requires_grad_(True) + fx_ref = linear(x_ref).float() + grad_proj = torch.randn_like(fx_ref) + (fx_ref * grad_proj).mean().backward() + + fx_ours = linear_custom(x_ours).float() + (fx_ours * grad_proj).mean().backward() + assert torch.allclose(fx_ref, fx_ours, atol=0.02) + assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) + assert not linear_custom.state.has_fp16_weights + assert linear_custom.state.CB is not None + assert linear_custom.state.CxB is None diff --git a/tests/test_modules.py b/tests/test_modules.py index ffcf304..d78f0c9 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -382,7 +382,7 @@ names = [f"threshold_{vals}" for vals in values] @pytest.mark.parametrize("threshold", values, ids=names) -@pytest.mark.parametrize("memory_efficient_backward", [True, False]) +@pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): l1 = ( bnb.nn.Linear8bitLt(