From 792f6213a78d0f8aab0852cb4e01aac8f3e63def Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 4 Jan 2023 17:38:33 +0100 Subject: [PATCH 01/52] Fix for python 3.7 --- bitsandbytes/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index ac7948b..f45fc34 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -42,7 +42,7 @@ print( ) print_header("OTHER") -print(f"{COMPILED_WITH_CUDA = }") +print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") cuda = get_cuda_lib_handle() print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}") print_header("") From 59bf8fcff2389558ff8deaa136c799798e488c21 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 4 Jan 2023 17:47:18 +0100 Subject: [PATCH 02/52] fix CUDASetup call --- bitsandbytes/cuda_setup/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index ce44d97..402b1a5 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -148,7 +148,7 @@ def is_cublasLt_compatible(cc): if cc is not None: cc_major, cc_minor = cc.split('.') if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5): - cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Proceeding to load CPU-only library...", is_warning=True) + CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Proceeding to load CPU-only library...", is_warning=True) else: has_cublaslt = True return has_cublaslt From c5372a856768624127da1e7f7232edae7a70cd9b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 5 Jan 2023 13:34:51 -0800 Subject: [PATCH 03/52] improve install instructions --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 9fa7ec9..d420e6c 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Resources: ## TL;DR **Requirements** Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. LLM.int8() requires Turing or Ampere GPUs. + **Installation**: ``pip install bitsandbytes`` @@ -58,6 +59,10 @@ The bitsandbytes library is currently only supported on Linux distributions. Win The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website. +To install run: + +``pip install bitsandbytes`` + ## Using bitsandbytes ### Using Int8 Matrix Multiplication From de535889348c5406eb34d9f7e0c362cadb113be5 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 1 Feb 2023 20:09:31 -0800 Subject: [PATCH 04/52] Added Int8 matmul support for all GPUs. Full backward support. 
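
This patch adds get_inverse_transform_indices()/undo_layout() so the backward pass can
recover a row-major CB from the col_turing/col_ampere weight layout, plus a fallback
matmul path used when the GPU has compute capability below (7, 5) or when
state.force_no_igemmlt is set.

A minimal usage sketch of the fallback through the module API (illustrative only; it
mirrors the new tests/test_linear8bitlt.py, and the shapes/dtypes are just example
assumptions, not a prescribed calling convention):

    import torch
    import bitsandbytes as bnb

    # int8 linear layer with mixed-precision decomposition (outlier threshold 6.0)
    linear = bnb.nn.Linear8bitLt(1024, 3072, bias=True, has_fp16_weights=False, threshold=6.0)
    linear = linear.cuda()  # .cuda() quantizes the weight into Int8Params (CB/SCB)

    x = torch.randn(3, 1024, dtype=torch.half, device="cuda", requires_grad=True)
    y = linear(x)                # igemmlt kernel on cc >= 7.5, int8 -> fp16 matmul fallback otherwise
    y.float().mean().backward()  # grad w.r.t. x flows through state.CB (or undo_layout(state.CxB))
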
--- Makefile | 4 +- bitsandbytes/autograd/__init__.py | 1 + bitsandbytes/autograd/_functions.py | 159 +++++++++++++++++++++------- bitsandbytes/nn/modules.py | 36 ++----- tests/test_linear8bitlt.py | 61 +++++++++++ tests/test_modules.py | 2 +- 6 files changed, 195 insertions(+), 68 deletions(-) create mode 100644 tests/test_linear8bitlt.py diff --git a/Makefile b/Makefile index 10eb779..7bee7ef 100644 --- a/Makefile +++ b/Makefile @@ -60,8 +60,8 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py index e69de29..6b9a7e4 100644 --- a/bitsandbytes/autograd/__init__.py +++ b/bitsandbytes/autograd/__init__.py @@ -0,0 +1 @@ +from ._functions import undo_layout, get_inverse_transform_indices diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index a115437..376fb8a 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -2,6 +2,7 @@ import operator import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 +from typing import Tuple, Optional import torch @@ -14,6 +15,12 @@ def prod(iterable): tensor = torch.Tensor + +# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: +# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py + + + """ This class pools outlier dimensions across layers. This is particularly important for small models where outlier features @@ -48,6 +55,51 @@ class GlobalOutlierPooler: return torch.Tensor(list(self.outliers)).to(torch.int64) +def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]): + """ + Compute a permutation of indices that invert the specified (tiled) matrix transformation + + :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2] + :param tile_size: higher-level tile dimensions, i.e. 
(8, 32) for Turing and (32, 32) for Ampere + :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size + :example: transform_tile function for the turing layout (bitsandbytes.functional as F) + :returns: indices + """ + d1, d2 = tile_size + assert 0 < d1 * d2 < 2**64 + tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2) + # encode each position in tile as a tuple of <= 8 unique bytes + permuted_tile_indices = torch.zeros_like(tile_indices) + for i in range(8): + # select i-th byte, apply transformation and trace where each index ended up + ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256 + sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous() + assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow" + permuted_tile_i = transform_tile(sample_tile_i) + ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128 + permuted_tile_indices += ith_permuted_indices * (256**i) + if d1 * d2 < 256**i: + break # if all indices fit in i bytes, stop early + return permuted_tile_indices + +def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: + """ + Undo a tiled permutation such as turing or ampere layout + + :param permuted_tensor: torch tensor in a permuted layout + :param tile_indices: reverse transformation indices, from get_inverse_transform_indices + :return: contiguous row-major tensor + """ + (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape + assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles" + tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() + outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda + outputs[tile_indices.flatten()] = tensor + outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows) + outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols) + return outputs.reshape(rows, cols).contiguous() + + class MatMul8bit(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, quant_type="vector", precision=None): @@ -171,6 +223,8 @@ matmul_cublas = MatMul8bit.apply @dataclass class MatmulLtState: + tile_indices: Optional[torch.Tensor] = None + force_no_igemmlt: bool = False CB = None CxB = None SB = None @@ -202,11 +256,22 @@ class MatmulLtState: self.SBt = None self.CBt = None + def get_tile_size(self): + assert self.formatB in ( + "col_turing", + "col_ampere", + ), f"please find this assert and manually enter tile size for {self.formatB}" + return (8, 32) if self.formatB == "col_turing" else (32, 32) + class MatMul8bitLt(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): - # default to pytorch behavior if inputs are empty + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): + using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt + # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: ctx.is_empty = True @@ -214,9 +279,9 @@ class MatMul8bitLt(torch.autograd.Function): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], 
dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B @@ -235,9 +300,7 @@ class MatMul8bitLt(torch.autograd.Function): # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A.to(torch.float16), threshold=state.threshold - ) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -248,12 +311,12 @@ class MatMul8bitLt(torch.autograd.Function): state.subB = B[:, idx].t().contiguous() state.idx = idx else: - if state.CxB is None: + if state.CxB is None and using_igemmlt: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: - if not state.has_fp16_weights and state.CxB is None: + if not state.has_fp16_weights and state.CxB is None and using_igemmlt: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None @@ -273,7 +336,10 @@ class MatMul8bitLt(torch.autograd.Function): state.SCBt, coo_tensorB, ) = F.double_quant(B.to(torch.float16)) - state.CxB, state.SB = F.transform(CB, to_order=formatB) + if using_igemmlt: + state.CxB, state.SB = F.transform(CB, to_order=formatB) + else: + state.CB = CB else: has_grad = False @@ -288,18 +354,17 @@ class MatMul8bitLt(torch.autograd.Function): # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) # else: # state.idx = outlier_idx - outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0) - .t() - .contiguous() - .to(A.dtype) - ) + if state.CxB is not None: + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + else: + outliers = state.CB[:, state.idx.long()].clone() + + state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] - shapeB = state.SB[0] + shapeB = state.SB[0] if state.SB else B.shape if len(input_shape) == 3: output_shape = (input_shape[0], input_shape[1], shapeB[0]) @@ -307,16 +372,25 @@ class MatMul8bitLt(torch.autograd.Function): output_shape = (input_shape[0], shapeB[0]) # 3. 
Matmul - C32A, SA = F.transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - # we apply the fused bias here + if using_igemmlt: + C32A, SA = F.transform(CA, "col32") + out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + if bias is None or bias.dtype == torch.float16: + # we apply the fused bias here + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A.dtype) + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A.dtype).add_(bias) - if bias is None or bias.dtype == torch.float16: - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A.dtype) - else: # apply bias separately - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A.dtype).add_(bias) + else: + A_wo_outliers = A.clone() + if state.idx is not None: + A_wo_outliers[:, state.idx.long()] = 0 + output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype)) + output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0)) + if bias is not None: + output = output.add_(bias) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: @@ -337,14 +411,13 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - - clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + clone_func = torch.clone if len(output_shape) == 3 else lambda x: x return clone_func(output.view(output_shape)) @staticmethod def backward(ctx, grad_output): if ctx.is_empty: - bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA = ctx.tensors @@ -359,9 +432,7 @@ class MatMul8bitLt(torch.autograd.Function): # Cast grad_output to fp16 if len(grad_output.shape) == 3: - grad_output = grad_output.reshape( - -1, grad_output.shape[-1] - ).contiguous() + grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: @@ -376,17 +447,29 @@ class MatMul8bitLt(torch.autograd.Function): if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) + state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. 
/ 127.0)) + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + elif state.CxB is not None: + + if state.tile_indices is None: + order, tile_size = state.formatB, state.get_tile_size() + transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) + with torch.no_grad(): + state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device) + + CB = ( + undo_layout(state.CxB, state.tile_indices) + .to(ctx.dtype_A) + .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + ) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception('State must contain either CBt or CB matrix for backward') + raise Exception("State must contain either CBt or CB or CxB matrix for backward") return grad_A, grad_B, None, grad_bias, None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index a623bf1..45df35e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -209,19 +209,10 @@ class Int8Params(torch.nn.Parameter): class Linear8bitLt(nn.Linear): - def __init__( - self, - input_features, - output_features, - bias=True, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, - ): - super().__init__( - input_features, output_features, bias - ) + def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, + memory_efficient_backward=False, threshold=0.0, index=None): + super().__init__(input_features, output_features, bias) + assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() self.index = index @@ -231,9 +222,7 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights - ) + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) def init_8bit_state(self): self.state.CB = self.weight.CB @@ -241,27 +230,20 @@ class Linear8bitLt(nn.Linear): self.weight.CB = None self.weight.SCB = None - def forward(self, x): + def forward(self, x: torch.Tensor): self.state.is_training = self.training - if self.weight.CB is not None: self.init_8bit_state() # weights are cast automatically as Int8Params, but the bias has to be cast manually - if self.bias is not None and self.bias.dtype != torch.float16: - self.bias.data = self.bias.data.half() + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights: - if not self.state.memory_efficient_backward and self.state.CB is not None: + if self.state.CB is not None and self.state.CxB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass # we no longer need the row-major weight del self.state.CB self.weight.data = self.state.CxB - elif self.state.memory_efficient_backward and self.state.CxB is not None: - # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass. - # Thus, we delete CxB from the state. 
- del self.state.CxB - return out diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py new file mode 100644 index 0000000..4676b66 --- /dev/null +++ b/tests/test_linear8bitlt.py @@ -0,0 +1,61 @@ +import bitsandbytes as bnb +import pytest +import torch +from bitsandbytes import functional as F + +from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout +from bitsandbytes.nn.modules import Linear8bitLt + +# contributed by Alex Borzunov, see: +# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), + reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", +) +def test_layout_exact_match(): + x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda() + for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"): + transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) + tile_indices = get_inverse_transform_indices(transform, tile_size) + cxb = transform(x) + + torch.cuda.synchronize() + restored_x = undo_layout(cxb, tile_indices) + torch.cuda.synchronize() + assert restored_x.is_contiguous() + assert torch.all(torch.eq(restored_x, x)) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +def test_linear_no_igemmlt(): + linear = torch.nn.Linear(1024, 3072) + x = torch.randn(3, 1024, dtype=torch.half) + linear_custom = Linear8bitLt( + linear.in_features, + linear.out_features, + linear.bias is not None, + has_fp16_weights=False, + threshold=6.0, + ) + linear_custom.state.force_no_igemmlt = True + + linear_custom.weight = bnb.nn.Int8Params( + linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False + ).to(linear.weight.dtype) + linear_custom.bias = linear.bias + linear = linear_custom.cuda() + linear = linear.half().cuda() + + x_ref = x.clone().cuda().requires_grad_(True) + x_ours = x.clone().cuda().requires_grad_(True) + fx_ref = linear(x_ref).float() + grad_proj = torch.randn_like(fx_ref) + (fx_ref * grad_proj).mean().backward() + + fx_ours = linear_custom(x_ours).float() + (fx_ours * grad_proj).mean().backward() + assert torch.allclose(fx_ref, fx_ours, atol=0.02) + assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) + assert not linear_custom.state.has_fp16_weights + assert linear_custom.state.CB is not None + assert linear_custom.state.CxB is None diff --git a/tests/test_modules.py b/tests/test_modules.py index ffcf304..d78f0c9 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -382,7 +382,7 @@ names = [f"threshold_{vals}" for vals in values] @pytest.mark.parametrize("threshold", values, ids=names) -@pytest.mark.parametrize("memory_efficient_backward", [True, False]) +@pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): l1 = ( bnb.nn.Linear8bitLt( From 0f5c3948709ae70cf733cefbd831aaea8a4e38c9 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 1 Feb 2023 20:27:01 -0800 Subject: [PATCH 05/52] Added version 0.37.0. 
--- CHANGELOG.md | 12 ++++++++++++ bitsandbytes/cuda_setup/main.py | 13 +++++++------ setup.py | 2 +- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77703a0..ac239de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -189,3 +189,15 @@ Improvements: - StableEmbedding layer now has device and dtype parameters to make it 1:1 replaceable with regular Embedding layers (@lostmsu) - runtime performance of block-wise quantization slightly improved - added error message for the case multiple libcudart.so are installed and bitsandbytes picks the wrong one + + +### 0.37.0 + +#### Int8 Matmul + backward support for all GPUs + +Features: + - Int8 MatmulLt now supports backward through inversion of the ColTuring/ColAmpere format. Slow, but memory efficient. Big thanks to @borzunov + - Int8 now supported on all GPUs. On devices with compute capability < 7.5, the Int weights are cast to 16/32-bit for the matrix multiplication. Contributed by @borzunov + +Improvements: + - Improved logging for the CUDA detection mechanism. diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index ce44d97..cd9573f 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -80,9 +80,10 @@ class CUDASetup: self.add_log_entry('python setup.py install') def initialize(self): - self.has_printed = False - self.lib = None - self.initialized = False + if not getattr(self, 'initialized', False): + self.has_printed = False + self.lib = None + self.initialized = False def run_cuda_setup(self): self.initialized = True @@ -103,7 +104,7 @@ class CUDASetup: legacy_binary_name = "libbitsandbytes_cpu.so" self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...") binary_path = package_dir / legacy_binary_name - if not binary_path.exists(): + if not binary_path.exists() or torch.cuda.is_available(): self.add_log_entry('') self.add_log_entry('='*48 + 'ERROR' + '='*37) self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:') @@ -112,6 +113,7 @@ class CUDASetup: self.add_log_entry('3. You have multiple conflicting CUDA libraries') self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!') self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.') + self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.') self.add_log_entry('='*80) self.add_log_entry('') self.generate_instructions() @@ -148,7 +150,7 @@ def is_cublasLt_compatible(cc): if cc is not None: cc_major, cc_minor = cc.split('.') if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5): - cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Proceeding to load CPU-only library...", is_warning=True) + cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True) else: has_cublaslt = True return has_cublaslt @@ -362,7 +364,6 @@ def evaluate_cuda_setup(): print('') print('='*35 + 'BUG REPORT' + '='*35) print('Welcome to bitsandbytes. 
For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') - print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link') print('='*80) if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None diff --git a/setup.py b/setup.py index 93df40e..e3f453e 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ def read(fname): setup( name=f"bitsandbytes", - version=f"0.36.0-2", + version=f"0.37.0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="8-bit optimizers and matrix multiplication routines.", From 58b09ee1b11c26344aaeee7ff648655f1d127202 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 21 Feb 2023 12:04:47 +0100 Subject: [PATCH 06/52] [WIP] Implement proper serialization of Linear8bitLt --- bitsandbytes/nn/modules.py | 28 +++++++++++++++++++ tests/test_linear8bitlt.py | 56 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 45df35e..754ba20 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -224,6 +224,34 @@ class Linear8bitLt(nn.Linear): self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + + # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data + weight_name = "SCB" + + # case 1: .cuda was called, SCB is in self.weight + param_from_weight = getattr(self.weight, weight_name) + # case 2: self.init_8bit_state was called, SCB is in self.state + param_from_state = getattr(self.state, weight_name) + + key_name = prefix + f"{weight_name}" + if param_from_weight is not None: + destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() + elif not self.state.has_fp16_weights and param_from_state is not None: + destination[key_name] = param_from_state if keep_vars else param_from_state.detach() + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + for key in unexpected_keys: + input_name = key[len(prefix):] + if input_name == "SCB": + input_param = state_dict[key] + self.weight.SCB.copy_(input_param) + unexpected_keys.remove(key) + def init_8bit_state(self): self.state.CB = self.weight.CB self.state.SCB = self.weight.SCB diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 4676b66..8edee58 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -1,11 +1,14 @@ -import bitsandbytes as bnb +from copy import deepcopy + import pytest import torch -from bitsandbytes import functional as F +import bitsandbytes as bnb +from bitsandbytes import functional as F from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout from bitsandbytes.nn.modules import Linear8bitLt + # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py @@ -26,6 +29,7 @@ def test_layout_exact_match(): assert restored_x.is_contiguous() assert torch.all(torch.eq(restored_x, x)) + 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") def test_linear_no_igemmlt(): linear = torch.nn.Linear(1024, 3072) @@ -43,7 +47,7 @@ def test_linear_no_igemmlt(): linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False ).to(linear.weight.dtype) linear_custom.bias = linear.bias - linear = linear_custom.cuda() + linear_custom = linear_custom.cuda() linear = linear.half().cuda() x_ref = x.clone().cuda().requires_grad_(True) @@ -59,3 +63,49 @@ def test_linear_no_igemmlt(): assert not linear_custom.state.has_fp16_weights assert linear_custom.state.CB is not None assert linear_custom.state.CxB is None + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("has_fp16_weights", [False, True]) +def test_linear_serialization(has_fp16_weights): + linear = torch.nn.Linear(16, 32) + x = torch.randn(3, 16, dtype=torch.half) + + linear_custom = Linear8bitLt( + linear.in_features, + linear.out_features, + linear.bias is not None, + has_fp16_weights=has_fp16_weights, + threshold=6.0, + ) + linear_custom.state.force_no_igemmlt = True + linear_custom.weight = bnb.nn.Int8Params( + linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights + ).to(linear.weight.dtype) + linear_custom.bias = linear.bias + linear_custom = linear_custom.cuda() + + x_first = x.clone().cuda().requires_grad_(True) + fx_first = linear_custom(x_first).float() + grad_proj = torch.randn_like(fx_first) + (fx_first * grad_proj).mean().backward() + + state_dict = deepcopy(linear_custom.state_dict()) + + new_linear_custom = Linear8bitLt( + linear.in_features, + linear.out_features, + linear.bias is not None, + has_fp16_weights=has_fp16_weights, + threshold=6.0, + ) + linear_custom.state.force_no_igemmlt = True + new_linear_custom = new_linear_custom.cuda() + new_linear_custom.load_state_dict(state_dict, strict=True) + + x_second = x.clone().cuda().requires_grad_(True) + fx_second = new_linear_custom(x_second).float() + (fx_second * grad_proj).mean().backward() + + assert torch.allclose(fx_first, fx_second, atol=1e-5) + assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) From ac3ab281e39cbc514ebef08823482d5b0cba42c1 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 06:01:04 +0100 Subject: [PATCH 07/52] Handle more cases in test_linear_serialization --- tests/test_linear8bitlt.py | 53 ++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 8edee58..1aafe3d 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -1,4 +1,7 @@ -from copy import deepcopy +import os +from contextlib import nullcontext +from itertools import product +from tempfile import TemporaryDirectory import pytest import torch @@ -66,10 +69,11 @@ def test_linear_no_igemmlt(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") -@pytest.mark.parametrize("has_fp16_weights", [False, True]) -def test_linear_serialization(has_fp16_weights): - linear = torch.nn.Linear(16, 32) - x = torch.randn(3, 16, dtype=torch.half) +@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda", + list(product([False, True], [False, True], [False, True]))) +def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda): + linear = torch.nn.Linear(32, 96) + x = torch.randn(3, 32, dtype=torch.half) 
linear_custom = Linear8bitLt( linear.in_features, @@ -78,19 +82,34 @@ def test_linear_serialization(has_fp16_weights): has_fp16_weights=has_fp16_weights, threshold=6.0, ) - linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=False, has_fp16_weights=has_fp16_weights - ).to(linear.weight.dtype) + linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights + ) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() + if serialize_before_forward: + state_dict_8bit = linear_custom.state_dict() + x_first = x.clone().cuda().requires_grad_(True) fx_first = linear_custom(x_first).float() grad_proj = torch.randn_like(fx_first) (fx_first * grad_proj).mean().backward() - state_dict = deepcopy(linear_custom.state_dict()) + if not serialize_before_forward: + state_dict_8bit = linear_custom.state_dict() + + with TemporaryDirectory() as tmpdir: + state_path_8bit = os.path.join(tmpdir, "state_8bit.pth") + state_path = os.path.join(tmpdir, "state.pth") + + torch.save(linear.state_dict(), state_path) + torch.save(state_dict_8bit, state_path_8bit) + + if not has_fp16_weights: + assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path) + + new_state_dict = torch.load(state_path_8bit) new_linear_custom = Linear8bitLt( linear.in_features, @@ -99,13 +118,21 @@ def test_linear_serialization(has_fp16_weights): has_fp16_weights=has_fp16_weights, threshold=6.0, ) - linear_custom.state.force_no_igemmlt = True + + if deserialize_before_cuda: + with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): + new_linear_custom.load_state_dict(new_state_dict, strict=True) + new_linear_custom = new_linear_custom.cuda() - new_linear_custom.load_state_dict(state_dict, strict=True) + + if not deserialize_before_cuda: + new_linear_custom.load_state_dict(new_state_dict, strict=True) x_second = x.clone().cuda().requires_grad_(True) fx_second = new_linear_custom(x_second).float() (fx_second * grad_proj).mean().backward() - assert torch.allclose(fx_first, fx_second, atol=1e-5) - assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) + # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised + if has_fp16_weights or not deserialize_before_cuda: + assert torch.allclose(fx_first, fx_second, atol=1e-5) + assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) From cd4d904a4ccc80c444e460d3aef20705895d2051 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 06:01:34 +0100 Subject: [PATCH 08/52] Raise an error when loading a quantized checkpoint before quantization --- bitsandbytes/nn/modules.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 754ba20..65b2102 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -248,6 +248,11 @@ class Linear8bitLt(nn.Linear): for key in unexpected_keys: input_name = key[len(prefix):] if input_name == "SCB": + if self.weight.SCB is None: + # buffers not yet initialized, can't call them directly without + raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " + "not supported. 
Please call module.cuda() before module.load_state_dict()") + input_param = state_dict[key] self.weight.SCB.copy_(input_param) unexpected_keys.remove(key) From cc608c04c292c906ff48223e386689c1c024f601 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 06:02:06 +0100 Subject: [PATCH 09/52] Revert the layout if weights were reordered --- bitsandbytes/nn/modules.py | 55 +++++++++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 65b2102..bd39856 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -9,6 +9,8 @@ import torch.nn.functional as F from torch import Tensor, device, dtype, nn import bitsandbytes as bnb +import bitsandbytes.functional +from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout from bitsandbytes.optim import GlobalOptimManager T = TypeVar("T", bound="torch.nn.Module") @@ -210,7 +212,7 @@ class Int8Params(torch.nn.Parameter): class Linear8bitLt(nn.Linear): def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, - memory_efficient_backward=False, threshold=0.0, index=None): + memory_efficient_backward=False, threshold=0.0, index=None): super().__init__(input_features, output_features, bias) assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() @@ -225,21 +227,48 @@ class Linear8bitLt(nn.Linear): self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) def _save_to_state_dict(self, destination, prefix, keep_vars): - super()._save_to_state_dict(destination, prefix, keep_vars) + if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None: + # reorder weight layout back from ampere/turing to row + reorder_layout = True + weight_clone = self.weight.data.clone() + else: + reorder_layout = False - # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data - weight_name = "SCB" + try: + if reorder_layout: + if self.state.tile_indices is None: + order, tile_size = self.state.formatB, self.state.get_tile_size() + transform = lambda x: \ + bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row", + to_order=order)[0].to(x.device) + with torch.no_grad(): + self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to( + self.state.CxB.device) - # case 1: .cuda was called, SCB is in self.weight - param_from_weight = getattr(self.weight, weight_name) - # case 2: self.init_8bit_state was called, SCB is in self.state - param_from_state = getattr(self.state, weight_name) + CB = ( + undo_layout(self.state.CxB, self.state.tile_indices) + ) - key_name = prefix + f"{weight_name}" - if param_from_weight is not None: - destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() - elif not self.state.has_fp16_weights and param_from_state is not None: - destination[key_name] = param_from_state if keep_vars else param_from_state.detach() + self.weight.data = CB + + super()._save_to_state_dict(destination, prefix, keep_vars) + + # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data + weight_name = "SCB" + + # case 1: .cuda was called, SCB is in self.weight + param_from_weight = getattr(self.weight, 
weight_name) + # case 2: self.init_8bit_state was called, SCB is in self.state + param_from_state = getattr(self.state, weight_name) + + key_name = prefix + f"{weight_name}" + if param_from_weight is not None: + destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() + elif not self.state.has_fp16_weights and param_from_state is not None: + destination[key_name] = param_from_state if keep_vars else param_from_state.detach() + finally: + if reorder_layout: + self.weight.data = weight_clone def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): From d15822a54b5a41e7b35c233616eb77cea337a06c Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 06:23:07 +0100 Subject: [PATCH 10/52] Refactor _tile_indices into a cached property, fix device bug --- bitsandbytes/autograd/_functions.py | 18 ++++++++++-------- bitsandbytes/nn/modules.py | 15 +-------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 376fb8a..988ac87 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -223,7 +223,7 @@ matmul_cublas = MatMul8bit.apply @dataclass class MatmulLtState: - tile_indices: Optional[torch.Tensor] = None + _tile_indices: Optional[torch.Tensor] = None force_no_igemmlt: bool = False CB = None CxB = None @@ -263,6 +263,15 @@ class MatmulLtState: ), f"please find this assert and manually enter tile size for {self.formatB}" return (8, 32) if self.formatB == "col_turing" else (32, 32) + @property + def tile_indices(self): + if self._tile_indices is None: + device = self.CxB.device + transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device) + with torch.no_grad(): + self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device) + return self._tile_indices + class MatMul8bitLt(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs @@ -455,13 +464,6 @@ class MatMul8bitLt(torch.autograd.Function): CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CxB is not None: - - if state.tile_indices is None: - order, tile_size = state.formatB, state.get_tile_size() - transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) - with torch.no_grad(): - state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device) - CB = ( undo_layout(state.CxB, state.tile_indices) .to(ctx.dtype_A) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index bd39856..31d8fd7 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -236,20 +236,7 @@ class Linear8bitLt(nn.Linear): try: if reorder_layout: - if self.state.tile_indices is None: - order, tile_size = self.state.formatB, self.state.get_tile_size() - transform = lambda x: \ - bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row", - to_order=order)[0].to(x.device) - with torch.no_grad(): - self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to( - self.state.CxB.device) - - CB = ( - undo_layout(self.state.CxB, self.state.tile_indices) - ) - - self.weight.data = CB + self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices) 
super()._save_to_state_dict(destination, prefix, keep_vars) From 24609b66af17c91fa691457dabf45674e81c73a9 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 06:24:58 +0100 Subject: [PATCH 11/52] Reduce diff --- bitsandbytes/nn/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 31d8fd7..8c4d688 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -212,7 +212,7 @@ class Int8Params(torch.nn.Parameter): class Linear8bitLt(nn.Linear): def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, - memory_efficient_backward=False, threshold=0.0, index=None): + memory_efficient_backward=False, threshold=0.0, index=None): super().__init__(input_features, output_features, bias) assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() From dba11b0b2e71645e764768cf4d404adae5582661 Mon Sep 17 00:00:00 2001 From: ubik2 Date: Mon, 6 Mar 2023 16:57:57 -0800 Subject: [PATCH 12/52] Update compile_from_source.md Add cuda12x to the list of targets --- compile_from_source.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compile_from_source.md b/compile_from_source.md index 2c4a6ad..c126341 100644 --- a/compile_from_source.md +++ b/compile_from_source.md @@ -1,7 +1,7 @@ # Compiling from source Basic steps. -1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cpuonly` +1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly` 2. `CUDA_VERSION=XXX python setup.py install` To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive). From 7247cb4554b5024afbc69efcbb6e92a53c77728f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 08:08:46 -0800 Subject: [PATCH 13/52] initial commit, slowly work from interface into the kernel --- bitsandbytes/optim/__init__.py | 1 + bitsandbytes/optim/lion.py | 115 +++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 bitsandbytes/optim/lion.py diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 8c8a8f4..53533ee 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit +from .lion import Lion, Lion8bit, Lion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py new file mode 100644 index 0000000..cd2a9da --- /dev/null +++ b/bitsandbytes/optim/lion.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from bitsandbytes.optim.optimizer import Optimizer1State + + +class Lion(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class Lion8bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class Lion32bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) From d43ea9722c3ff754141a3d2844fb09363a89911f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 09:45:33 -0800 Subject: [PATCH 14/52] make sure interface is correct --- bitsandbytes/optim/lion.py | 52 +++++++++----------------------------- 1 file changed, 12 insertions(+), 40 deletions(-) diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index cd2a9da..a2fb6af 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -9,30 +9,21 @@ class Lion(Optimizer1State): def __init__( self, params, - lr=1e-2, - alpha=0.99, - eps=1e-8, + lr=1e-4, + betas=(0.9, 0.99), weight_decay=0, - momentum=0, - centered=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, ): - if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) - if centered: - raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( "rmsprop", params, lr, - (alpha, momentum), - eps, + betas, + 0., weight_decay, optim_bits, args, @@ -46,29 +37,20 @@ class Lion8bit(Optimizer1State): def __init__( self, params, - lr=1e-2, - alpha=0.99, - eps=1e-8, + lr=1e-4, + betas=(0.9, 0.99), weight_decay=0, - momentum=0, - centered=False, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, ): - if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" 
- ) - if centered: - raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( "rmsprop", params, lr, - (alpha, momentum), - eps, + betas, + 0., weight_decay, 8, args, @@ -82,30 +64,20 @@ class Lion32bit(Optimizer1State): def __init__( self, params, - lr=1e-2, - alpha=0.99, - eps=1e-8, + lr=1e-4, + betas=(0.9, 0.99), weight_decay=0, - momentum=0, - centered=False, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, ): - - if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) - if centered: - raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( "rmsprop", params, lr, - (alpha, momentum), - eps, + betas, + 0., weight_decay, 32, args, From cb4c3c8c66405caca780ccf3d4af3d8cc581f9fd Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 10:10:19 -0800 Subject: [PATCH 15/52] do a bunch of typical bookkeeping before getting to main lion logic --- README.md | 2 +- bitsandbytes/functional.py | 12 ++++++++++++ bitsandbytes/optim/lion.py | 6 +++--- csrc/kernels.cu | 16 ++++++++++++++++ csrc/ops.cu | 9 +++++++++ csrc/ops.cuh | 1 + csrc/pythonInterface.c | 12 ++++++++++++ tests/test_optim.py | 4 ++++ 8 files changed, 58 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d420e6c..dfd91cd 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ out = linear(x.to(torch.float16)) ## Features - 8-bit Matrix multiplication with mixed precision decomposition - LLM.int8() inference -- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory) +- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory) - Stable Embedding Layer: Improved stability through better initialization, and normalization - 8-bit quantization: Quantile, Linear, and Dynamic quantization - Fast quantile estimation: Up to 100x faster than other algorithms diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 95a7c4f..166e38f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA: lib.crmsprop32bit_g32, lib.crmsprop32bit_g16, ) + str2optimizer32bit["lion"] = ( + lib.clion32bit_g32, + lib.clion32bit_g16, + ) str2optimizer32bit["adagrad"] = ( lib.cadagrad32bit_g32, lib.cadagrad32bit_g16, @@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA: lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16, ) + str2optimizer8bit["lion"] = ( + lib.clion_static_8bit_g32, + lib.clion_static_8bit_g16, + ) str2optimizer8bit["lamb"] = ( lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16, @@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA: lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16, ) + str2optimizer8bit_blockwise["lion"] = ( + lib.clion_8bit_blockwise_fp32, + lib.clion_8bit_blockwise_fp16, + ) str2optimizer8bit_blockwise["adagrad"] = ( lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16, diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index a2fb6af..4a00f57 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -19,7 +19,7 @@ class Lion(Optimizer1State): block_wise=True, ): super().__init__( - "rmsprop", + "lion", params, lr, betas, @@ -46,7 +46,7 @@ class Lion8bit(Optimizer1State): block_wise=True, ): super().__init__( - "rmsprop", + "lion", params, lr, betas, @@ -73,7 +73,7 @@ class Lion32bit(Optimizer1State): block_wise=True, ): super().__init__( - "rmsprop", + "lion", params, lr, betas, diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 
08b9b44..a871a55 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -790,6 +790,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value @@ -890,6 +891,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p, p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); @@ -1219,6 +1221,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c if(unorm != NULL) local_unorm += s1_vals[j]*s1_vals[j]; break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; @@ -1321,6 +1324,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); @@ -1664,6 +1668,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char else s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; + case LION: case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; @@ -1701,6 +1706,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char case MOMENTUM: p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; + case LION: case RMSPROP: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); @@ -2699,6 +2705,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) @@ -2710,6 +2718,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) @@ -2742,6 +2752,8 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half) MAKE_PreconditionStatic8bit1State(MOMENTUM, float) MAKE_PreconditionStatic8bit1State(RMSPROP, half) MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) #define MAKE_optimizerStatic8bit1State(oname, gtype) \ template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ @@ -2758,6 +2770,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half) MAKE_optimizerStatic8bit1State(MOMENTUM, float) MAKE_optimizerStatic8bit1State(RMSPROP, half) MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) 
+MAKE_optimizerStatic8bit1State(LION, float) #define MAKE_PreconditionStatic8bit2State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ @@ -2849,5 +2863,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/ops.cu b/csrc/ops.cu index e770e10..cdd8a27 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -120,6 +120,7 @@ template void optimizer32bit(T* g, T* p, case MOMENTUM: case RMSPROP: case ADAGRAD: + case LION: if(max_unorm > 0.0f) { @@ -163,6 +164,7 @@ template void optimizerStatic8bit(T* p, T* g, case MOMENTUM: case RMSPROP: case ADAGRAD: + case LION: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -198,6 +200,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g case MOMENTUM: case RMSPROP: case ADAGRAD: + case LION: num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, @@ -707,6 +710,8 @@ MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) @@ -726,6 +731,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half) MAKE_optimizerStatic8bit(MOMENTUM, float) MAKE_optimizerStatic8bit(RMSPROP, half) MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ @@ -738,6 +745,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 31d4dd8..9f06435 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -70,6 +70,7 @@ typedef enum Optimizer_t RMSPROP = 2, LARS = 3, ADAGRAD = 4, + LION = 5, } Optimizer_t; typedef enum Transform_t diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index d8b2290..4caa7e8 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32) MAKE_FUNC32(adam, ADAM, half, 16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) +MAKE_FUNC32(lion, LION, float, 32) +MAKE_FUNC32(lion, LION, half, 16) MAKE_FUNC32(adagrad, ADAGRAD, float, 32) 
MAKE_FUNC32(adagrad, ADAGRAD, half, 16) @@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32) MAKE_FUNC8(momentum, MOMENTUM, half, 16) MAKE_FUNC8(rmsprop, RMSPROP, float, 32) MAKE_FUNC8(rmsprop, RMSPROP, half, 16) +MAKE_FUNC8(lion, LION, float, 32) +MAKE_FUNC8(lion, LION, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ @@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) +MAKE_BLOCKWISE8(lion, LION, half, 16) +MAKE_BLOCKWISE8(lion, LION, float, 32) MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) @@ -161,6 +167,8 @@ extern "C" MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) + MAKE_CFUNC32(lion, float, 32) + MAKE_CFUNC32(lion, half, 16) MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) @@ -183,6 +191,8 @@ extern "C" MAKE_CFUNC8(momentum, half, 16) MAKE_CFUNC8(rmsprop, float, 32) MAKE_CFUNC8(rmsprop, half, 16) + MAKE_CFUNC8(lion, float, 32) + MAKE_CFUNC8(lion, half, 16) #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ @@ -196,6 +206,8 @@ extern "C" MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) + MAKE_CBLOCKWISE8(lion, LION, half, 16) + MAKE_CBLOCKWISE8(lion, LION, float, 32) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) diff --git a/tests/test_optim.py b/tests/test_optim.py index 3df2dad..a11ba85 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -50,6 +50,10 @@ str2optimizers["rmsprop"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), ) +str2optimizers["rmsprop"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), +) str2optimizers["adam8bit"] = ( torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), From 8de29fc3645075a4fa7f9d37201a5fb2cdc364ca Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 10:11:32 -0800 Subject: [PATCH 16/52] forget about tests for now, will test live on local enwik8 training --- tests/test_optim.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index a11ba85..3df2dad 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -50,10 +50,6 @@ str2optimizers["rmsprop"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["rmsprop"] = ( - lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), -) str2optimizers["adam8bit"] = ( torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), From 64bb1ae8d176ca8661bd3f76518e97ec4f863506 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 11:10:28 -0800 Subject: [PATCH 17/52] add a sign function, for lion --- csrc/kernels.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a871a55..76a8c73 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -217,6 +217,14 @@ __device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float * } } +// 
sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template +__device__ int sgn(T val) { + return (T(0) < val) - (val < T(0)); +} + __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) { const int tid = threadIdx.x + (blockDim.x*blockIdx.x); From c83888aa1aab50fde54ccad19114e781e2fc62a4 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 11:54:54 -0800 Subject: [PATCH 18/52] use epsilon as beta2 for lion, complete most of the logic in kernel.cu for all functions --- bitsandbytes/optim/lion.py | 17 ++++++++------- csrc/kernels.cu | 42 ++++++++++++++++++++++++++++---------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 4a00f57..81a9efe 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -18,12 +18,13 @@ class Lion(Optimizer1State): percentile_clipping=100, block_wise=True, ): + beta1, beta2 = betas super().__init__( "lion", params, lr, - betas, - 0., + (beta1, 0.), + beta2, weight_decay, optim_bits, args, @@ -44,13 +45,14 @@ class Lion8bit(Optimizer1State): min_8bit_size=4096, percentile_clipping=100, block_wise=True, - ): + ): + beta1, beta2 = betas super().__init__( "lion", params, lr, - betas, - 0., + (beta1, 0.), + beta2, weight_decay, 8, args, @@ -72,12 +74,13 @@ class Lion32bit(Optimizer1State): percentile_clipping=100, block_wise=True, ): + beta1, beta2 = betas super().__init__( "lion", params, lr, - betas, - 0., + (beta1, 0.), + beta2, weight_decay, 32, args, diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 76a8c73..553f884 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) { return __int_as_float(old); } +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template +__device__ int sgn(T val) { + return (T(0) < val) - (val < T(0)); +} + template __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) { @@ -217,14 +225,6 @@ __device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float * } } -// sign function for lion -// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA - -template -__device__ int sgn(T val) { - return (T(0) < val) - (val < T(0)); -} - __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) { const int tid = threadIdx.x + (blockDim.x*blockIdx.x); @@ -799,6 +799,10 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; case LION: + // using eps as beta2 + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*(float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value @@ -899,7 +903,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p, p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); break; - case LION: + case LION: + // using eps as beta2 + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = 
s1_vals[j]*eps + ((1.0f-eps)*((float)g_vals[j])); + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); @@ -1230,6 +1238,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c local_unorm += s1_vals[j]*s1_vals[j]; break; case LION: + // using eps as beta2 + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; @@ -1333,6 +1344,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); break; case LION: + // using eps as beta2 + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); @@ -1676,7 +1691,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char else s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; - case LION: + case LION: + // using eps as beta2 + s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; @@ -1714,7 +1732,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char case MOMENTUM: p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; - case LION: + case LION: + p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))); + break; case RMSPROP: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); From ead570a43ea8a8205860d6bc294f8735bc8df091 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 11:58:31 -0800 Subject: [PATCH 19/52] remove something rmsprop specific --- csrc/kernels.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 553f884..a59cbb0 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -801,7 +801,6 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, case LION: // using eps as beta2 s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*(float)g_vals[j]); // state update - s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update From af034309922af39b32d1bbdee747b23513ecf2bc Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 14:03:07 -0800 Subject: [PATCH 20/52] fix weight decay for lion to be decoupled, using a switch --- csrc/kernels.cu | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index a59cbb0..8fa5a34 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1328,8 +1328,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, { g_val = float(g_vals[j]); g_val *= gnorm_scale; - if(weight_decay > 0.0f) - g_val += ((float)p_vals[j])*weight_decay; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; switch(OPTIMIZER) @@ 
-1677,8 +1688,17 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char g_val *= gnorm_scale; if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { - if(weight_decay > 0.0f) - g_val += ((float)p_vals[j])*weight_decay; + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; From c5582724d5213e558f2a8ceef25e09c115cb8ef8 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 9 Mar 2023 14:05:45 -0800 Subject: [PATCH 21/52] missed adagrad --- csrc/kernels.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 8fa5a34..0c5493b 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1691,6 +1691,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char if(weight_decay > 0.0f) { switch(OPTIMIZER) { case MOMENTUM: + case ADAGRAD: case RMSPROP: g_val += ((float)p_vals[j])*weight_decay; break; From 8618bed001e4888758e6b45c0716b781013056f8 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 10 Mar 2023 08:39:06 -0800 Subject: [PATCH 22/52] swap the order in which momentum and parameters are updated in ops.cu --- csrc/ops.cu | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/csrc/ops.cu b/csrc/ops.cu index cdd8a27..384aff7 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -120,8 +120,6 @@ template void optimizer32bit(T* g, T* p, case MOMENTUM: case RMSPROP: case ADAGRAD: - case LION: - if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); @@ -132,6 +130,18 @@ template void optimizer32bit(T* g, T* p, kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; + case LION: + // in lion, the momentum update after the parameter update + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + break; } } @@ -164,7 +174,6 @@ template void optimizerStatic8bit(T* p, T* g, case MOMENTUM: case RMSPROP: case ADAGRAD: - case LION: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -172,6 +181,16 @@ template void optimizerStatic8bit(T* p, T* g, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; + case LION: + // in lion, the momentum update happens after the parameter update + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + 
CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; default: break; } From c99b44f774e334aff5491fc822fc1cdc816a172f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 10 Mar 2023 08:57:59 -0800 Subject: [PATCH 23/52] do the epsilon beta2 switcharoo within the cuda code, and not within the python class (so that the state dict still makes sense) --- bitsandbytes/optim/lion.py | 15 ++++++--------- csrc/ops.cu | 8 ++++---- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 81a9efe..c267af7 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -18,13 +18,12 @@ class Lion(Optimizer1State): percentile_clipping=100, block_wise=True, ): - beta1, beta2 = betas super().__init__( "lion", params, lr, - (beta1, 0.), - beta2, + (beta1, beta2), + 0., weight_decay, optim_bits, args, @@ -46,13 +45,12 @@ class Lion8bit(Optimizer1State): percentile_clipping=100, block_wise=True, ): - beta1, beta2 = betas super().__init__( "lion", params, lr, - (beta1, 0.), - beta2, + (beta1, beta2), + 0., weight_decay, 8, args, @@ -74,13 +72,12 @@ class Lion32bit(Optimizer1State): percentile_clipping=100, block_wise=True, ): - beta1, beta2 = betas super().__init__( "lion", params, lr, - (beta1, 0.), - beta2, + betas, + 0., weight_decay, 32, args, diff --git a/csrc/ops.cu b/csrc/ops.cu index 384aff7..51c530e 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -132,13 +132,13 @@ template void optimizer32bit(T* g, T* p, break; case LION: // in lion, the momentum update after the parameter update - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } break; @@ -183,12 +183,12 @@ template void optimizerStatic8bit(T* p, T* g, break; case LION: // in lion, the momentum update happens after the parameter update - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; default: From 19b9ef34b955fec483803848636f79df3d5a97e3 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 10 Mar 2023 08:59:49 -0800 Subject: [PATCH 24/52] whoops --- bitsandbytes/optim/lion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index c267af7..2551b68 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -22,7 +22,7 @@ class 
Lion(Optimizer1State): "lion", params, lr, - (beta1, beta2), + betas, 0., weight_decay, optim_bits, @@ -49,7 +49,7 @@ class Lion8bit(Optimizer1State): "lion", params, lr, - (beta1, beta2), + betas, 0., weight_decay, 8, From abbe65adfc42b5fb2ee1b57a73ea76f24f81252c Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 10 Mar 2023 12:50:14 -0800 Subject: [PATCH 25/52] beta2 is actually accessible in kOptimizerStatic8bit1StateBlockwise --- csrc/kernels.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 0c5493b..1833383 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1712,8 +1712,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; case LION: - // using eps as beta2 - s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); From 6c377b39b6f79ea8ec33f87c71dc44ac6b092ac3 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 10 Mar 2023 13:00:59 -0800 Subject: [PATCH 26/52] always pass beta2 into all the 1state functions --- csrc/kernels.cu | 18 ++++++++++-------- csrc/kernels.cuh | 8 ++++---- csrc/ops.cu | 16 ++++++++-------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1833383..87ea692 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -751,7 +751,7 @@ template __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, - const float beta1, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n) { @@ -833,7 +833,7 @@ template __launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State(T *g, T *p, float *state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { @@ -1175,7 +1175,7 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, float *unorm, - const float beta1, + const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, float* new_max1, @@ -1238,7 +1238,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c break; case LION: // using eps as beta2 - s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); @@ -1265,7 +1265,7 @@ template __global__ void kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, - const float beta1, + const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* max1, float* new_max1, @@ -1356,7 +1356,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, case LION: // using eps as beta2 p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + 
((1.0f-beta1)*((float)g_val)))); - s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*g_val); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); @@ -2745,7 +2745,7 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ - const float beta1, const float eps, const float weight_decay, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, const int n); \ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) @@ -2759,7 +2759,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -2788,6 +2788,7 @@ template __global__ void kOptimizer32bit2State(float* g, float* p, template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ float *unorm, \ const float beta1, \ + const float beta2, \ const float eps, const int step, \ float* __restrict__ const quantiles1, \ float* max1, float* new_max1, \ @@ -2806,6 +2807,7 @@ MAKE_PreconditionStatic8bit1State(LION, float) template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ const float *unorm, const float max_unorm, const float param_norm, \ const float beta1, \ + const float beta2, \ const float eps, const int step, const float lr, \ float* __restrict__ const quantiles1, \ float* max1, float* new_max1, \ diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index d90ea13..a8aa3fc 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p, template __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, - const float beta1, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n); template __global__ void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, float *unorm, - const float beta1, + const float beta1, const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, float* new_max1, @@ -57,7 +57,7 @@ template __global__ void 
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, - const float beta1, + const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* max1, float* new_max1, diff --git a/csrc/ops.cu b/csrc/ops.cu index 51c530e..94d5f2e 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -123,22 +123,22 @@ template void optimizer32bit(T* g, T* p, if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case LION: // in lion, the momentum update after the parameter update - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } break; @@ -175,20 +175,20 @@ template void optimizerStatic8bit(T* p, T* g, case RMSPROP: case ADAGRAD: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case LION: // in lion, the momentum update happens after the parameter update - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, step, lr, + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; default: From 
369a51c432c310a74de6e185ea072af4a398ec67 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 10 Mar 2023 14:08:35 -0800 Subject: [PATCH 27/52] switch all eps to beta2 --- csrc/kernels.cu | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 87ea692..98a3188 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -799,8 +799,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; case LION: - // using eps as beta2 - s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*(float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update @@ -903,9 +902,8 @@ __global__ void kOptimizer32bit1State(T *g, T *p, p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); break; case LION: - // using eps as beta2 p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); - s1_vals[j] = s1_vals[j]*eps + ((1.0f-eps)*((float)g_vals[j])); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); @@ -1237,7 +1235,6 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c local_unorm += s1_vals[j]*s1_vals[j]; break; case LION: - // using eps as beta2 s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: @@ -1354,7 +1351,6 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); break; case LION: - // using eps as beta2 p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; From c4866ab06e9f65b400c609fcffdbaa35d8d2c039 Mon Sep 17 00:00:00 2001 From: Severin Gsponer Date: Sat, 11 Mar 2023 15:35:23 +0100 Subject: [PATCH 28/52] Fix #157; Add XDG_GREETER_DATA_DIR to ignorelist --- bitsandbytes/cuda_setup/env_vars.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py index 536a7d8..4fcb643 100644 --- a/bitsandbytes/cuda_setup/env_vars.py +++ b/bitsandbytes/cuda_setup/env_vars.py @@ -11,6 +11,7 @@ def to_be_ignored(env_var: str, value: str) -> bool: "HOME", # Linux shell default "TMUX", # Terminal Multiplexer "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff "XDG_RUNTIME_DIR", "MAIL", # something related to emails "SHELL", # binary for currently invoked shell From 2c8352e316d5428f57f47ec8b557dc9c9caf427f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 12 Mar 2023 10:24:25 -0700 Subject: [PATCH 29/52] Bumped version. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e3f453e..442577c 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ def read(fname): setup( name=f"bitsandbytes", - version=f"0.37.0", + version=f"0.37.1", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="8-bit optimizers and matrix multiplication routines.", From 1b0aabc7e4c8a4a78ff9b3a0aec199d55dc27135 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 21 Mar 2023 14:06:08 -0700 Subject: [PATCH 30/52] Added CUDA 12.1. 
addressing #201 --- cuda_install.sh | 8 ++++++-- deploy.sh | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/cuda_install.sh b/cuda_install.sh index d84eb56..0f888b9 100644 --- a/cuda_install.sh +++ b/cuda_install.sh @@ -12,6 +12,7 @@ URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installer URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run URL120=https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run +URL121=https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run CUDA_VERSION=$1 @@ -60,11 +61,14 @@ if [[ -n "$CUDA_VERSION" ]]; then elif [[ "$CUDA_VERSION" -eq "120" ]]; then URL=$URL120 FOLDER=cuda-12.0 + elif [[ "$CUDA_VERSION" -eq "121" ]]; then + URL=$URL121 + FOLDER=cuda-12.1 else - echo "argument error: No cuda version passed as input. Choose among: {111, 115}" + echo "argument error: No cuda version passed as input. Choose among versions 92 to 121" fi else - echo "argument error: No cuda version passed as input. Choose among: {111, 115}" + echo "argument error: No cuda version passed as input. Choose among versions 92 to 112" fi FILE=$(basename $URL) diff --git a/deploy.sh b/deploy.sh index d473bb5..44db970 100644 --- a/deploy.sh +++ b/deploy.sh @@ -128,6 +128,16 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then exit 64 fi +make clean +export CUDA_HOME=$BASE_PATH/cuda-12.1 +make cuda12x CUDA_VERSION=121 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + make clean export CUDA_HOME=$BASE_PATH/cuda-10.2 @@ -241,5 +251,15 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then exit 64 fi +make clean +export CUDA_HOME=$BASE_PATH/cuda-12.1 +make cuda12x_nomatmul CUDA_VERSION=121 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + python -m build python -m twine upload dist/* --verbose From 49a04253fb1f3e195cb0e9e79bdb01db7a490774 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 21 Mar 2023 15:10:19 -0700 Subject: [PATCH 31/52] Bumped version for CUDA 12.1 support release. 
--- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 442577c..a0bbc7f 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ def read(fname): setup( name=f"bitsandbytes", - version=f"0.37.1", + version=f"0.37.2", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="8-bit optimizers and matrix multiplication routines.", From dcecbb26cafc040052e78b09b1cbe06929a9b776 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Wed, 22 Mar 2023 00:28:49 +0100 Subject: [PATCH 32/52] Add force_no_igemmlt to test params --- tests/test_linear8bitlt.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 1aafe3d..37f7af9 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -69,9 +69,9 @@ def test_linear_no_igemmlt(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") -@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda", - list(product([False, True], [False, True], [False, True]))) -def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda): +@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt", + list(product([False, True], [False, True], [False, True], [False, True]))) +def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt): linear = torch.nn.Linear(32, 96) x = torch.randn(3, 32, dtype=torch.half) @@ -82,6 +82,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri has_fp16_weights=has_fp16_weights, threshold=6.0, ) + if force_no_igemmlt: + linear_custom.state.force_no_igemmlt = True + linear_custom.weight = bnb.nn.Int8Params( linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights ) @@ -118,6 +121,8 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri has_fp16_weights=has_fp16_weights, threshold=6.0, ) + if force_no_igemmlt: + new_linear_custom.state.force_no_igemmlt = True if deserialize_before_cuda: with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): From 9b656f461acc8ed8a461382828e042799eb8e402 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 22 Mar 2023 07:52:59 -0700 Subject: [PATCH 33/52] follow advice of Tim to fix update of momentum vs parameters in blockwise 8 bit --- csrc/kernels.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 98a3188..3c3445f 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1708,6 +1708,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; case LION: + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: @@ -1748,7 +1749,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; case LION: - p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); break; case RMSPROP: g_val = g_vals[j]; From a43cd2008d2f369aa7e47624cae942f87b7f8d6f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 22 Mar 2023 09:14:05 
-0700 Subject: [PATCH 34/52] add some code in test_optim.py, although it seems to be failing --- requirements.txt | 1 + tests/test_optim.py | 23 ++++++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e079f8a..883b2e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ +lion-pytorch pytest diff --git a/tests/test_optim.py b/tests/test_optim.py index 3df2dad..9f815ab 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -7,6 +7,8 @@ from itertools import product from os.path import join import pytest +from lion_pytorch import Lion + import torch import bitsandbytes as bnb @@ -31,6 +33,7 @@ str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) +str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) str2optimizers["momentum_pytorch"] = ( None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), @@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = ( ) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers["lion"] = (Lion, bnb.optim.Lion) str2optimizers["momentum"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), @@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = ( torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), ) +str2optimizers["lion8bit"] = ( + Lion, + lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False), +) str2optimizers["momentum8bit"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), @@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = ( torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), ) +str2optimizers["lion8bit_blockwise"] = ( + Lion, + lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True), +) str2optimizers["momentum8bit_blockwise"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), @@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = ( str2statenames = {} str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["lion"] = [("exp_avg", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["lars"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] @@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [ ("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2"), ] +str2statenames["lion8bit"] = [ + ("exp_avg", "state1", "qmap1", "max1") +] str2statenames["lamb8bit"] = [ ("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2"), @@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [ ("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2"), ] +str2statenames["lion8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1") +] str2statenames["momentum8bit"] = [ ("momentum_buffer", "state1", "qmap1", "max1") ] @@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [ dim1 = [1024] dim2 = [32, 1024, 4097, 1] gtype = [torch.float32, torch.float16] -optimizer_names = ["adam", "momentum", "rmsprop", "lars"] +optimizer_names = 
["adam", "momentum", "rmsprop", "lars", "lion"] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values @@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] optimizer_names = [ "adam8bit", + "lion8bit", "momentum8bit", "rmsprop8bit", "adam8bit_blockwise", + "lion8bit_blockwise", "lars8bit", "momentum8bit_blockwise", "rmsprop8bit_blockwise", From aa9b939edda7779118bcb66618454c9c8c57e986 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 22 Mar 2023 09:22:19 -0700 Subject: [PATCH 35/52] add some comments, and fix use of g_val --- csrc/kernels.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 3c3445f..253576a 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1708,7 +1708,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; case LION: - g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))); + // here, using gvals[j] to store the gradient smoothed by beta1 + // then update the momentum state1, to make sure the order is correct + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: From 916000c8bfab0611f7698852cd8b688a48597d22 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 22 Mar 2023 09:27:13 -0700 Subject: [PATCH 36/52] fix consistent tabs / spaces --- csrc/kernels.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 253576a..1c3a782 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1708,9 +1708,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; case LION: - // here, using gvals[j] to store the gradient smoothed by beta1 - // then update the momentum state1, to make sure the order is correct - g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + // here, using gvals[j] to store the gradient smoothed by beta1 + // then update the momentum state1, to make sure the order is correct + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: From 978ba2db57e473d3c4351497b399db65c92d9522 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 22 Mar 2023 09:33:47 -0700 Subject: [PATCH 37/52] another tab/spaces fix --- csrc/kernels.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1c3a782..1778c98 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1682,10 +1682,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char { g_val = float(g_vals[j]); g_val *= gnorm_scale; - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - if(weight_decay > 0.0f) { - switch(OPTIMIZER) { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { case MOMENTUM: case ADAGRAD: case RMSPROP: From 2a6828e6fbc4bd81d44ca67de45563b3b2876c14 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 22 Mar 2023 09:56:50 -0700 Subject: [PATCH 38/52] fix comment --- csrc/kernels.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1778c98..e0df802 100644 --- a/csrc/kernels.cu +++ 
b/csrc/kernels.cu @@ -1708,8 +1708,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; case LION: - // here, using gvals[j] to store the gradient smoothed by beta1 - // then update the momentum state1, to make sure the order is correct + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; From b6383ba116f987796886ebd4f6126fd064511e3b Mon Sep 17 00:00:00 2001 From: Ji Lin Date: Wed, 22 Mar 2023 22:14:57 -0400 Subject: [PATCH 39/52] fix a bug in quantize_no_absmax and dequantize_no_absmax with multiple gpus --- bitsandbytes/functional.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 95a7c4f..3f7b328 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -656,9 +656,11 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: torch.Tensor: Quantized 8-bit tensor. ''' + prev_device = pre_call(A.device) if out is None: out = torch.zeros_like(A, dtype=torch.uint8) is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + post_call(prev_device) return out @@ -683,9 +685,11 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: torch.Tensor: 32-bit output tensor. ''' + prev_device = pre_call(A.device) if out is None: out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + post_call(prev_device) return out From 8cceff72db4d3ac1022dd239e1a91caae9c2f91b Mon Sep 17 00:00:00 2001 From: Jeongseok Kang Date: Wed, 5 Apr 2023 09:28:41 +0900 Subject: [PATCH 40/52] Fixed typo libsbitsandbytes_cpu.so --- bitsandbytes/cuda_setup/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index ffa80ba..5f06890 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -365,7 +365,7 @@ def evaluate_cuda_setup(): print('='*35 + 'BUG REPORT' + '='*35) print('Welcome to bitsandbytes. 
For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') print('='*80) - if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None + if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None cuda_setup = CUDASetup.get_instance() cudart_path = determine_cuda_runtime_lib_path() From 5e456be50e66a266d4c36e8e405c11f634edb8f0 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 10 Apr 2023 21:26:52 +0300 Subject: [PATCH 41/52] Support 1650, 1660 --- bitsandbytes/autograd/_functions.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 376fb8a..4db9a92 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -221,6 +221,17 @@ bmm_cublas = MatMul8bit.apply matmul_cublas = MatMul8bit.apply +def supports_igemmlt(device: torch.device) -> bool: + """check if this device supports the optimized int8 kernel""" + if torch.cuda.get_device_capability(device=device) < (7, 5): + return False + device_name = torch.cuda.get_device_name(device=device) + nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series + if any(model_name in device_name for model_name in nvidia16_models): + return False # these devices are technically cuda 7.5-capable, but they lack tensor cores + return True + + @dataclass class MatmulLtState: tile_indices: Optional[torch.Tensor] = None @@ -270,7 +281,7 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): - using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt + using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: From 0b2ebcdab96022a558985685432faa9620d63647 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 08:37:02 -0700 Subject: [PATCH 42/52] Added launch bounds to fix launch resource error for Lion. --- csrc/kernels.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index e0df802..e1ec00d 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1260,6 +1260,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c template __global__ void +__launch_bounds__(1024, 1) kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, From 792af5c8838568d47e6421fece9dcb7460b20adc Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 08:42:41 -0700 Subject: [PATCH 43/52] Fixed noisy tests for 8-bit Lion. 
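For readers following the Lion work in this series (the sign function from PATCH 17/52, the beta2 handling in PATCH 18/52 and 27/52, decoupled weight decay in PATCH 20/52, and the update ordering fixed in PATCH 22/52 and 33/52), the rule the CUDA kernels converge on is easier to see in plain PyTorch. The sketch below is illustrative only: the function name and signature are not part of bitsandbytes, it assumes a single fp32 momentum buffer, and the decoupled weight decay mirrors the switch added to the 8-bit kernels.

```python
import torch

def lion_update_sketch(p, grad, state1, lr, beta1, beta2, weight_decay):
    """Hypothetical reference for one Lion step; not a bitsandbytes API."""
    # decoupled weight decay: scale the parameter itself instead of adding wd*p to the gradient
    if weight_decay > 0.0:
        p.mul_(1.0 - lr * weight_decay)
    # the parameter step uses the sign of the beta1-interpolated gradient
    p.add_(torch.sign(beta1 * state1 + (1.0 - beta1) * grad), alpha=-lr)
    # the momentum buffer is refreshed with beta2 only after the parameter step,
    # which is why ops.cu launches the LION optimizer kernel before the precondition kernel
    state1.mul_(beta2).add_(grad, alpha=1.0 - beta2)
    return p, state1
```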
--- tests/test_optim.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index 9f815ab..96c2a7b 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -18,6 +18,13 @@ import bitsandbytes.functional as F k = 20 +def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): + idx = torch.isclose(a, b, rtol, atol) + error_count = (idx == 0).sum().item() + if error_count > max_error_count: + print(f"Too many values not close: assert {sumval} < {count}") + torch.testing.assert_allclose(a, b, rtol, atol) + def get_temp_dir(): path = f"/tmp/autoswap/{str(uuid.uuid4())}" @@ -306,7 +313,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch_optimizer.step() - torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 5 errors for Lion + assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5) dequant_states = [] for name1, name2, qmap, max_val in str2statenames[optim_name]: @@ -388,9 +397,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): == 0 ) assert num_not_close.sum().item() < 20 - torch.testing.assert_allclose( - p1, p2.float(), atol=patol, rtol=prtol - ) + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 5 errors for Lion + assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5) # the parameters diverge quickly. Here we keep them close # together so we can test against the Adam error From 2eb310835668f854c169953814f1d3b16a44346b Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 09:16:01 -0700 Subject: [PATCH 44/52] Fixed bug where beta2 was not passed into Lion 32-bit. 
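On the Python side these kernels are exposed through bnb.optim.Lion, bnb.optim.Lion8bit and bnb.optim.Lion32bit, the classes defined in bitsandbytes/optim/lion.py and exercised in tests/test_optim.py. A minimal usage sketch follows; the hyperparameter values are placeholders chosen for illustration, not defaults documented by this series.

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()

# 32-bit Lion; betas=(beta1, beta2) and decoupled weight_decay as in bitsandbytes/optim/lion.py
optimizer = bnb.optim.Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)
# the 8-bit variant with blockwise quantized state is also available:
# optimizer = bnb.optim.Lion8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.99), block_wise=True)

loss = model(torch.randn(8, 512, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```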
--- bitsandbytes/optim/optimizer.py | 2 +- tests/test_optim.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 867ad3d..1adf5d4 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -665,7 +665,7 @@ class Optimizer1State(Optimizer8bit): step, config["lr"], None, - 0.0, + config['betas'][1], config["weight_decay"], gnorm_scale, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, diff --git a/tests/test_optim.py b/tests/test_optim.py index 96c2a7b..839f80c 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -22,7 +22,7 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): idx = torch.isclose(a, b, rtol, atol) error_count = (idx == 0).sum().item() if error_count > max_error_count: - print(f"Too many values not close: assert {sumval} < {count}") + print(f"Too many values not close: assert {error_count} < {max_error_count}") torch.testing.assert_allclose(a, b, rtol, atol) @@ -170,6 +170,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch_optimizer.step() + for name1, name2 in str2statenames[optim_name]: torch.testing.assert_allclose( torch_optimizer.state[p1][name1], @@ -178,7 +179,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): rtol=rtol, ) - torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 10 errors for Lion + assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -188,14 +191,15 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 10 errors for Lion + assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose( - torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2], - atol=atol, - rtol=rtol, - ) + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 10 errors for Lion + assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], + atol=atol, rtol=rtol, + max_error_count=10) if gtype == torch.float16: # the adam buffers should also be close because they are 32-bit @@ -343,7 +347,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) - relerr = err / torch.abs(p1) + relerr = err / (torch.abs(p1)+1e-9) assert err.mean() < 0.0001 assert relerr.mean() < 0.001 From 29ab3a6b1433a4100e62f54b1382f4dd1028c06f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 09:26:52 -0700 Subject: [PATCH 45/52] Updated change log. --- CHANGELOG.md | 14 ++++++++++++++ setup.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac239de..66578d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -201,3 +201,17 @@ Features: Improvements: - Improved logging for the CUDA detection mechanism. 
+ +### 0.38.0 + +#### 8-bit Lion, Load/Store 8-bit layers + +Features: + - Support for 32 and 8-bit Lion has been added. Thank you @lucidrains + - Support for serialization of Linear8bitLt layers (LLM.int8()). This allows to store and load 8-bit weights directly from the HuggingFace Hub. Thank you @myrab + +Bug fixes: + - Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins + +Deprecated: + - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0. diff --git a/setup.py b/setup.py index a0bbc7f..b023c0b 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ def read(fname): setup( name=f"bitsandbytes", - version=f"0.37.2", + version=f"0.38.0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="8-bit optimizers and matrix multiplication routines.", From 2bb5c00ba9b0af840e9226a6100f2e968c0763f4 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 09:36:56 -0700 Subject: [PATCH 46/52] Added pre/post call to all lib calls. Fixes #120 --- bitsandbytes/functional.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9840b47..8d95789 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -770,6 +770,8 @@ def optimizer_update_32bit( f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}' ) + prev_device = pre_call(g.device) + is_on_gpu([g, p, state1, state2, unorm_vec]) if g.dtype == torch.float32 and state1.dtype == torch.float32: str2optimizer32bit[optimizer_name][0]( get_ptr(g), @@ -812,6 +814,7 @@ def optimizer_update_32bit( raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) + post_call(prev_device) def optimizer_update_8bit( @@ -890,6 +893,8 @@ def optimizer_update_8bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) + prev_device = pre_call(g.device) + is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2]) if g.dtype == torch.float32 and state1.dtype == torch.uint8: str2optimizer8bit[optimizer_name][0]( get_ptr(p), @@ -942,6 +947,7 @@ def optimizer_update_8bit( raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) + post_call(prev_device) def optimizer_update_8bit_blockwise( @@ -964,6 +970,8 @@ def optimizer_update_8bit_blockwise( skip_zeros=False, ) -> None: + prev_device = pre_call(g.device) + is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) if g.dtype == torch.float32 and state1.dtype == torch.uint8: str2optimizer8bit_blockwise[optimizer_name][0]( get_ptr(p), @@ -1008,6 +1016,7 @@ def optimizer_update_8bit_blockwise( raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) + post_call(prev_device) def percentile_clipping( @@ -1023,6 +1032,7 @@ def percentile_clipping( The current optimiation steps (number of past gradient norms). 
""" + prev_device = pre_call(grad.device) is_on_gpu([grad, gnorm_vec]) if grad.dtype == torch.float32: lib.cpercentile_clipping_g32( @@ -1040,6 +1050,7 @@ def percentile_clipping( ) else: raise ValueError(f"Gradient type {grad.dtype} not supported!") + post_call(prev_device) current_gnorm = torch.sqrt(gnorm_vec[step % 100]) vals, idx = torch.sort(gnorm_vec) @@ -1796,6 +1807,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype ) nnz = cooA.nnz + prev_device = pre_call(B.device) assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz assert cooA.values.numel() == nnz @@ -1872,6 +1884,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ccolsB, ) # else: assertion error + post_call(prev_device) return out From 4cd63deff3b3cdc923e151c4efdaa281fa42d668 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 12:10:20 -0700 Subject: [PATCH 47/52] Fixed CUDA Conda PyTorch 2.0 issues. --- README.md | 17 +++++++ bitsandbytes/cextension.py | 6 ++- bitsandbytes/cuda_setup/main.py | 28 +++++++---- cuda_install.sh | 13 +++-- tests/test_cuda_setup_evaluator.py | 79 +----------------------------- 5 files changed, 50 insertions(+), 93 deletions(-) diff --git a/README.md b/README.md index dfd91cd..de6b27b 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,25 @@ Resources: Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. LLM.int8() requires Turing or Ampere GPUs. **Installation**: + ``pip install bitsandbytes`` +In some cases it can happen that you need to compile from source. In that case, you can install CUDA with the install script in the repository. No sudo is required for this install. + +```bash +wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh +# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121} +# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True + +# For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc +bash cuda install 118 ~/local 1 +``` + +To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`, for example the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the cuda version at `~/local/cuda-11.7`: + +``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` + **Using 8-bit optimizer**: 1. Comment out optimizer: ``#torch.optim.Adam(....)`` 2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 7a62c1e..85bef00 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -11,8 +11,6 @@ from bitsandbytes.cuda_setup.main import CUDASetup setup = CUDASetup.get_instance() if setup.initialized != True: setup.run_cuda_setup() - if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - setup.print_log_stack() lib = setup.lib try: @@ -31,3 +29,7 @@ except AttributeError: warn("The installed version of bitsandbytes was compiled without GPU support. 
" "8-bit optimizers and GPU quantization are unavailable.") COMPILED_WITH_CUDA = False + +# print the setup details after checking for errors so we do not print twice +if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': + setup.print_log_stack() diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index 5f06890..776bee5 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -21,12 +21,21 @@ import os import errno import torch from warnings import warn +from itertools import product from pathlib import Path from typing import Set, Union from .env_vars import get_potentially_lib_path_containing_env_vars -CUDA_RUNTIME_LIB: str = "libcudart.so" +# these are the most common libs names +# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead +# we have libcudart.so.11.0 which causes a lot of errors before +# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt +CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0'] + +# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths +backup_paths = [] +backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0') class CUDASetup: _instance = None @@ -98,6 +107,8 @@ class CUDASetup: package_dir = Path(__file__).parent.parent binary_path = package_dir / binary_name + print('bin', binary_path) + try: if not binary_path.exists(): self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?") @@ -117,7 +128,6 @@ class CUDASetup: self.add_log_entry('='*80) self.add_log_entry('') self.generate_instructions() - self.print_log_stack() raise Exception('CUDA SETUP: Setup Failed!') self.lib = ct.cdll.LoadLibrary(binary_path) else: @@ -125,7 +135,6 @@ class CUDASetup: self.lib = ct.cdll.LoadLibrary(binary_path) except Exception as ex: self.add_log_entry(str(ex)) - self.print_log_stack() def add_log_entry(self, msg, is_warning=False): self.cuda_setup_log.append((msg, is_warning)) @@ -178,11 +187,12 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]: - return { - path / CUDA_RUNTIME_LIB - for path in candidate_paths - if (path / CUDA_RUNTIME_LIB).is_file() - } + paths = set() + for libname in CUDA_RUNTIME_LIBS: + for path in candidate_paths: + if (path / libname).is_file(): + paths.add(path / libname) + return paths def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: @@ -257,7 +267,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: cuda_runtime_libs.update(find_cuda_lib_in(value)) if len(cuda_runtime_libs) == 0: - CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...') + CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. 
Searching in backup paths...') cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) warn_in_case_of_duplicates(cuda_runtime_libs) diff --git a/cuda_install.sh b/cuda_install.sh index 0f888b9..b333f33 100644 --- a/cuda_install.sh +++ b/cuda_install.sh @@ -17,6 +17,7 @@ URL121=https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installer CUDA_VERSION=$1 BASE_PATH=$2 +EXPORT_BASHRC=$3 if [[ -n "$CUDA_VERSION" ]]; then if [[ "$CUDA_VERSION" -eq "92" ]]; then @@ -76,11 +77,13 @@ FILE=$(basename $URL) if [[ -n "$CUDA_VERSION" ]]; then echo $URL echo $FILE - wget $URL - bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent - echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc - echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc - source ~/.bashrc + #wget $URL + #bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent + if [ "$EXPORT_BASHRC" -eq "1" ]; then + echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc + echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc + source ~/.bashrc + fi else echo "" fi diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index c0da1d3..4973da5 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -5,95 +5,20 @@ import pytest import bitsandbytes as bnb from bitsandbytes.cuda_setup.main import ( - CUDA_RUNTIME_LIB, determine_cuda_runtime_lib_path, evaluate_cuda_setup, extract_candidate_paths, ) -""" -'LD_LIBRARY_PATH': ':/mnt/D/titus/local/cuda-11.1/lib64/' -'CONDA_EXE': '/mnt/D/titus/miniconda/bin/conda' -'LESSCLOSE': '/usr/bin/lesspipe %s %s' -'OLDPWD': '/mnt/D/titus/src' -'CONDA_PREFIX': '/mnt/D/titus/miniconda/envs/8-bit' -'SSH_AUTH_SOCK': '/mnt/D/titus/.ssh/ssh-agent.tim-uw.sock' -'CONDA_PREFIX_1': '/mnt/D/titus/miniconda' -'PWD': '/mnt/D/titus/src/8-bit' -'HOME': '/mnt/D/titus' -'CONDA_PYTHON_EXE': '/mnt/D/titus/miniconda/bin/python' -'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/' -'TMUX': '/tmp/tmux-1007/default,59286,1' -'XDG_DATA_DIRS': '/usr/local/share:/usr/share:/var/lib/snapd/desktop' -'SSH_TTY': '/dev/pts/0' -'MAIL': '/var/mail/titus' -'SHELL': '/bin/bash' -'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/1007/bus' -'XDG_RUNTIME_DIR': '/run/user/1007' -'PATH': '/mnt/D/titus/miniconda/envs/8-bit/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/mnt/D/titus/local/cuda-11.1/bin' -'LESSOPEN': '| /usr/bin/lesspipe %s' -'_': '/mnt/D/titus/miniconda/envs/8-bit/bin/python' -# any that include 'CONDA' that are not 'CONDA_PREFIX' -# we search for -'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/' -""" - - -class InputAndExpectedOutput(NamedTuple): - input: str - output: str - - -HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [ - ( - f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", - f"dir/with/{CUDA_RUNTIME_LIB}", - ), - ( - f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", - f"dir/with/{CUDA_RUNTIME_LIB}", - ), - ( - f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", - f"dir/with/{CUDA_RUNTIME_LIB}", - ), - ( - f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", - f"dir/with/{CUDA_RUNTIME_LIB}", - ), - ( - f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", - f"dir/with/{CUDA_RUNTIME_LIB}", - ), - ( - f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", - f"dir/with/{CUDA_RUNTIME_LIB}", - ), -] - - 
-@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS) -def happy_path_path_string(tmpdir, request): - for path in extract_candidate_paths(request.param): - test_dir.mkdir() - if CUDA_RUNTIME_LIB in path: - (test_input / CUDA_RUNTIME_LIB).touch() - -UNHAPPY_PATH__LD_LIB_TEST_PATHS = [ - f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}", - f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}", -] - - -def test_full_system(): +def test_cuda_full_system(): ## this only tests the cuda version and not compute capability # if CONDA_PREFIX exists, it has priority before all other env variables # but it does not contain the library directly, so we need to look at the a sub-folder version = "" if "CONDA_PREFIX" in os.environ: - ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so') + ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so.11.0') major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".")) version = float(f"{major}.{minor}") From 89e3b82731db66eb4bb0c0690f1f623c8ef6df65 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 13:47:10 -0700 Subject: [PATCH 48/52] Added more detailed cuda setup debug and debugging instructions. --- README.md | 50 ++++++++++++++----- bitsandbytes/__main__.py | 87 ++++++++++++++++++++++++++++----- bitsandbytes/cuda_setup/main.py | 3 +- compile_from_source.md | 27 +++++++--- cuda_install.sh | 4 +- errors_and_solutions.md | 2 +- 6 files changed, 139 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index de6b27b..600401c 100644 --- a/README.md +++ b/README.md @@ -11,27 +11,40 @@ Resources: ## TL;DR **Requirements** -Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. LLM.int8() requires Turing or Ampere GPUs. +Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. + +(Deprecated: CUDA 10.0 is deprecated and only CUDA >= 11.0) will be supported with release 0.39.0) **Installation**: ``pip install bitsandbytes`` -In some cases it can happen that you need to compile from source. In that case, you can install CUDA with the install script in the repository. No sudo is required for this install. +In some cases it can happen that you need to compile from source. If this happens please consider submitting a bug report with `python -m bitsandbytes` information. What now follows is some short instructions which might work out of the box if `nvcc` is installed. If these do not work see further below. 
+Compilation quickstart: ```bash -wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh -# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH -# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121} -# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True +git clone https://github.com/timdettmers/bitsandbytes.git +cd bitsandbytes -# For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc -bash cuda install 118 ~/local 1 +# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 120} +# make argument in {cuda110, cuda11x, cuda12x} +# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes +CUDA_VERSION=117 make cuda11x +python setup.py install ``` -To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`, for example the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the cuda version at `~/local/cuda-11.7`: +**Using Int8 inference with HuggingFace Transformers** -``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` +```python +from transformers import AutoModelForCausalLM +model = AutoModelForCausalLM.from_pretrained( + 'decapoda-research/llama-7b-hf, + device_map='auto', + load_in_8bit=True, + max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB') +``` + +A more detailed example, can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py). **Using 8-bit optimizer**: 1. Comment out optimizer: ``#torch.optim.Adam(....)`` @@ -130,8 +143,23 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.m 2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_) ## Compile from source +To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands. -To compile from source, please follow the [compile_from_source.md](compile_from_source.md) instructions. +```bash +wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh +# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121} +# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True + +# For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc +bash cuda install 118 ~/local 1 +``` + +To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`, for example the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the cuda version at `~/local/cuda-11.7`: + +``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` + +For more detailed instruction, please follow the [compile_from_source.md](compile_from_source.md) instructions. 
## License diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index f45fc34..a100b29 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -1,11 +1,82 @@ import os import sys +import shlex +import subprocess + from warnings import warn +from typing import Tuple +from os.path import isdir import torch HEADER_WIDTH = 60 +def execute_and_return(command_string: str) -> Tuple[str, str]: + def _decode(subprocess_err_out_tuple): + return tuple( + to_decode.decode("UTF-8").strip() + for to_decode in subprocess_err_out_tuple + ) + + def execute_and_return_decoded_std_streams(command_string): + return _decode( + subprocess.Popen( + shlex.split(command_string), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).communicate() + ) + + std_out, std_err = execute_and_return_decoded_std_streams(command_string) + return std_out, std_err + +def find_file_recursive(folder, filename): + cmd = f'find {folder} -name {filename}' + out, err = execute_and_return(cmd) + if len(err) > 0: + raise RuntimeError('Something when wrong when trying to find file. Maybe you do not have a linux system?') + + return out + + +def generate_bug_report_information(): + print_header("") + print_header("BUG REPORT INFORMATION") + print_header("") + print('') + + if 'CONDA_PREFIX' in os.environ: + paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so') + print_header("ANACONDA CUDA PATHS") + print(paths) + print('') + if isdir('/usr/local/'): + paths = find_file_recursive('/usr/local', '*cuda*so') + print_header("/usr/local CUDA PATHS") + print(paths) + print('') + + if isdir(os.getcwd()): + paths = find_file_recursive(os.getcwd(), '*cuda*so') + print_header("WORKING DIRECTORY CUDA PATHS") + print(paths) + print('') + + print_header("LD_LIBRARY CUDA PATHS") + lib_path = os.environ['LD_LIBRARY_PATH'].strip() + for path in set(lib_path.split(':')): + try: + if isdir(path): + print_header(f"{path} CUDA PATHS") + paths = find_file_recursive(path, '*cuda*so') + print(paths) + except: + print(f'Could not read LD_LIBRARY_PATH: {path}') + print('') + + + + def print_header( txt: str, width: int = HEADER_WIDTH, filler: str = "+" @@ -21,25 +92,13 @@ def print_debug_info() -> None: ) -print_header("") -print_header("DEBUG INFORMATION") -print_header("") -print() +generate_bug_report_information() from . 
import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL from .cuda_setup.env_vars import to_be_ignored from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle -print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS") -for k, v in os.environ.items(): - if "/" in v and not to_be_ignored(k, v): - print(f"'{k}': '{v}'") -print_header("") - -print( - "\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n" -) print_header("OTHER") print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") @@ -55,6 +114,7 @@ Running a quick check that: + CUDA function is callable """ ) +print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n") try: from bitsandbytes.optim import Adam @@ -91,3 +151,4 @@ except Exception as e: print(e) print_debug_info() sys.exit(1) + diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index 776bee5..2cadbd7 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -373,7 +373,8 @@ def evaluate_cuda_setup(): if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': print('') print('='*35 + 'BUG REPORT' + '='*35) - print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') + print(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'), + ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')) print('='*80) if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None diff --git a/compile_from_source.md b/compile_from_source.md index c126341..7edb33f 100644 --- a/compile_from_source.md +++ b/compile_from_source.md @@ -1,20 +1,35 @@ # Compiling from source Basic steps. -1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly` -2. `CUDA_VERSION=XXX python setup.py install` +1. `CUDA_VERSION=XXX make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly` +2. `python setup.py install` To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive). -For your convenience, there is an installation script in the root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. 
To do this you can add this to your `.bashrc` by executing these commands: +You can install CUDA locally without sudo by following the following steps: + ```bash -echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64/" >> ~/.bashrc -echo "export PATH=$PATH:/usr/local/cuda/bin/" >> ~/.bashrc -source ~/.bashrc +wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh +# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121} +# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True + +# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc +bash cuda install 117 ~/local 1 ``` By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler. Either `nvcc` needs to be in path for the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed +If you type `nvcc` and it cannot be found, you might need to add to your path or set the CUDA_HOME variable. You can run `python -m bitsandbytes` to find the path to CUDA. For example if `python -m bitsandbytes` shows you the following: +``` +++++++++++++++++++ /usr/local CUDA PATHS +++++++++++++++++++ +/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so +``` +You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be able to compile like this. + +``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` + + If you have problems compiling the library with these instructions from source, please open an issue. diff --git a/cuda_install.sh b/cuda_install.sh index b333f33..2e6c7d1 100644 --- a/cuda_install.sh +++ b/cuda_install.sh @@ -77,8 +77,8 @@ FILE=$(basename $URL) if [[ -n "$CUDA_VERSION" ]]; then echo $URL echo $FILE - #wget $URL - #bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent + wget $URL + bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent if [ "$EXPORT_BASHRC" -eq "1" ]; then echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc diff --git a/errors_and_solutions.md b/errors_and_solutions.md index 5e8b2d2..5b8cbcd 100644 --- a/errors_and_solutions.md +++ b/errors_and_solutions.md @@ -1,6 +1,6 @@ # No kernel image available -This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. So solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``? +This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. 
This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``? If you are feeling lucky, you can also try to compile the library from source. This can be still problematic if your PATH variables have multiple cuda versions. As such, it is recommended to figure out path conflicts before you proceed with compilation. From eb1c331c843cd16ad3c5444fcb0a0ddafc87febe Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 15:49:01 -0700 Subject: [PATCH 49/52] Updates README and CHANGELOG. --- CHANGELOG.md | 8 +++++++- README.md | 2 +- compile_from_source.md | 2 +- cuda_install.sh | 2 +- deploy.sh | 4 ++-- 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 66578d5..5399c02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -204,14 +204,20 @@ Improvements: ### 0.38.0 -#### 8-bit Lion, Load/Store 8-bit layers +#### 8-bit Lion, Load/Store 8-bit Models directly from/to HF Hub Features: - Support for 32 and 8-bit Lion has been added. Thank you @lucidrains - Support for serialization of Linear8bitLt layers (LLM.int8()). This allows to store and load 8-bit weights directly from the HuggingFace Hub. Thank you @myrab + - New bug report features `python -m bitsandbytes` now gives extensive debugging details to debug CUDA setup failures. Bug fixes: - Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins + - Fixed a bug where cudart.so libraries could not be found in newer PyTorch releases. + +Improvements: + - Improved the CUDA Setup procedure by doing a more extensive search for CUDA libraries Deprecated: - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0. + - Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0 diff --git a/README.md b/README.md index 600401c..727a86c 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ To compile from source, you need an installation of CUDA. 
If `nvcc` is not insta ```bash wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh # Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH -# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121} +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121} # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True # For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc diff --git a/compile_from_source.md b/compile_from_source.md index 7edb33f..9d4f89d 100644 --- a/compile_from_source.md +++ b/compile_from_source.md @@ -11,7 +11,7 @@ You can install CUDA locally without sudo by following the following steps: ```bash wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh # Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH -# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121} +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121} # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True # For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc diff --git a/cuda_install.sh b/cuda_install.sh index 2e6c7d1..678f7ca 100644 --- a/cuda_install.sh +++ b/cuda_install.sh @@ -77,7 +77,7 @@ FILE=$(basename $URL) if [[ -n "$CUDA_VERSION" ]]; then echo $URL echo $FILE - wget $URL + #wget $URL bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent if [ "$EXPORT_BASHRC" -eq "1" ]; then echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc diff --git a/deploy.sh b/deploy.sh index 44db970..24d6cbf 100644 --- a/deploy.sh +++ b/deploy.sh @@ -10,8 +10,8 @@ if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then fi -module unload cuda -module unload gcc +module unload cuda && echo "no module function available. Probably not on a slurm cluster." +module unload gcc && echo "no module function available. Probably not on a slurm cluster." rm -rf dist build make cleaneggs From 659a7dfc7165b166b7972250c39daa0b90ad501d Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 11 Apr 2023 16:14:29 -0700 Subject: [PATCH 50/52] Fixing #300. --- bitsandbytes/cuda_setup/main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index 2cadbd7..3c4e7f3 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -212,12 +212,12 @@ def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: if len(results_paths) > 1: warning_msg = ( - f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. " + f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. 
" "We'll flip a coin and try one of these, in order to fail forward.\n" "Either way, this might cause trouble in the future:\n" "If you get `CUDA error: invalid device function` errors, the above " "might be the cause and the solution is to make sure only one " - f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.") + f"{CUDA_RUNTIME_LIBS} in the paths that we search based on your env.") CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) @@ -245,7 +245,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: return next(iter(conda_cuda_libs)) CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' - f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True) + f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) if "LD_LIBRARY_PATH" in candidate_env_vars: lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) @@ -255,7 +255,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: warn_in_case_of_duplicates(lib_ld_cuda_libs) CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' - f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True) + f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) remaining_candidate_env_vars = { env_var: value for env_var, value in candidate_env_vars.items() From 7c651012fce87881bb4e194a26af25790cadea4f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 12 Apr 2023 07:56:52 -0700 Subject: [PATCH 51/52] Added better error message for debugging on CUDA not detected failures. --- bitsandbytes/cextension.py | 12 ++++++++---- setup.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 85bef00..a1f1d4c 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -18,16 +18,20 @@ try: CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().print_log_stack() raise RuntimeError(''' - CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment! - If you cannot find any issues and suspect a bug, please open an issue with detals about your environment: - https://github.com/TimDettmers/bitsandbytes/issues''') + CUDA Setup failed despite GPU being available. Please run the following command to get more information: + + python -m bitsandbytes + + Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them + to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes + and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') lib.cadam32bit_g32 lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError: warn("The installed version of bitsandbytes was compiled without GPU support. 
" - "8-bit optimizers and GPU quantization are unavailable.") + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") COMPILED_WITH_CUDA = False # print the setup details after checking for errors so we do not print twice diff --git a/setup.py b/setup.py index b023c0b..e514463 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ def read(fname): setup( name=f"bitsandbytes", - version=f"0.38.0", + version=f"0.38.0.post2", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="8-bit optimizers and matrix multiplication routines.", From 32f8c89201e85f8405ec263d40baeb6daf84c3cb Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 12 Apr 2023 11:27:31 -0700 Subject: [PATCH 52/52] Added missing example folder. --- examples/int8_inference_huggingface.py | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 examples/int8_inference_huggingface.py diff --git a/examples/int8_inference_huggingface.py b/examples/int8_inference_huggingface.py new file mode 100644 index 0000000..dc80a44 --- /dev/null +++ b/examples/int8_inference_huggingface.py @@ -0,0 +1,27 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +MAX_NEW_TOKENS = 128 +model_name = 'decapoda-research/llama-7b-hf' + +text = 'Hamburg is in which country?\n' +tokenizer = AutoTokenizer.from_pretrained(model_name) +input_ids = tokenizer(text, return_tensors="pt").input_ids + +free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3) +max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' + +n_gpus = torch.cuda.device_count() +max_memory = {i: max_memory for i in range(n_gpus)} + +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map='auto', + load_in_8bit=True, + max_memory=max_memory +) +generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) +print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) + + +