diff --git a/CHANGELOG.md b/CHANGELOG.md index 7950044..9a596ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,7 +49,7 @@ Features: Bug fixes: - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13 - Fixed an unsafe use of eval. #8 - - Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15 + - Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15 Docs: - Added instructions how to solve "\_\_fatbinwrap_" errors. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7996343..0fae0ac 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -28,4 +28,4 @@ outlined on that page and do not file a public issue. ## License By contributing to bitsandbytes, you agree that your contributions will be licensed -under the LICENSE file in the root directory of this source tree. \ No newline at end of file +under the LICENSE file in the root directory of this source tree. diff --git a/Makefile b/Makefile index 6194fe3..9cc46f5 100644 --- a/Makefile +++ b/Makefile @@ -26,14 +26,14 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags -COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler -COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler +COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler +COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta -COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta +COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not CC_CUDA92 := -gencode arch=compute_30,code=sm_30 @@ -58,38 +58,38 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' 
-dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) cuda110_nomatmul: $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) cuda11x_nomatmul: $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) cuda110: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda11x: $(BUILD_DIR) env $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC 
$(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cpuonly: $(BUILD_DIR) env @@ -117,7 +117,7 @@ $(ROOT_DIR)/dependencies/cub: cd dependencies/cub; git checkout 1.11.0 clean: - rm build/* + rm build/* cleaneggs: rm -rf *.egg* diff --git a/README.md b/README.md index 7d35a80..0b7286f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # bitsandbytes -The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions. +The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions. @@ -48,7 +48,7 @@ out = linear(x.to(torch.float16)) Requirements: anaconda, cudatoolkit, pytorch -Hardware requirements: +Hardware requirements: - LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); (a GPU from 2018 or older). - 8-bit optimizers and quantization: NVIDIA Maxwell GPU or newer (>=GTX 9XX). @@ -87,7 +87,7 @@ Note that by default all parameter tensors with less than 4096 elements are kept ``` # parameter tensors with less than 16384 values are optimized in 32-bit # it is recommended to use multiplies of 4096 -adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) +adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) ``` ### Change Bits and other Hyperparameters for Individual Parameters diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 6d1177f..041df4b 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from . import cuda_setup, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, @@ -12,7 +13,6 @@ from .autograd._functions import ( ) from .cextension import COMPILED_WITH_CUDA from .nn import modules -from . import cuda_setup, utils if COMPILED_WITH_CUDA: from .optim import adam diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index 175a30e..ac7948b 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -1,6 +1,3 @@ -# from bitsandbytes.debug_cli import cli - -# cli() import os import sys from warnings import warn @@ -31,8 +28,8 @@ print() from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL -from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle from .cuda_setup.env_vars import to_be_ignored +from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS") for k, v in os.environ.items(): diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 2ddb406..a115437 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,12 +1,13 @@ import operator import warnings - -import torch -import bitsandbytes.functional as F - from dataclasses import dataclass from functools import reduce # Required in Python 3 +import torch + +import bitsandbytes.functional as F + + # math.prod not compatible with python < 3.8 def prod(iterable): return reduce(operator.mul, iterable, 1) @@ -15,10 +16,10 @@ tensor = torch.Tensor """ This class pools outlier dimensions across layers. 
- This is particularly important for small models where outlier features + This is particularly important for small models where outlier features are less systematic and occur with low frequency. """ -class GlobalOutlierPooler(object): +class GlobalOutlierPooler: _instance = None def __init__(self): @@ -49,8 +50,9 @@ class GlobalOutlierPooler(object): class MatMul8bit(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]): - + def forward(ctx, A, B, out=None, quant_type="vector", precision=None): + if precision is None: + precision = [8, 8, 8] if precision[0] != 8: with torch.no_grad(): output = torch.matmul(A, B) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 264e899..d140f4c 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,11 +1,11 @@ import ctypes as ct -import torch - from pathlib import Path from warnings import warn +import torch -class CUDASetup(object): + +class CUDASetup: _instance = None def __init__(self): @@ -122,7 +122,7 @@ try: CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().print_log_stack() raise RuntimeError(''' - CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs aboveto fix your environment! + CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment! If you cannot find any issues and suspect a bug, please open an issue with detals about your environment: https://github.com/TimDettmers/bitsandbytes/issues''') lib.cadam32bit_g32 diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py index d8ebba8..e781b9d 100644 --- a/bitsandbytes/cuda_setup/__init__.py +++ b/bitsandbytes/cuda_setup/__init__.py @@ -1,2 +1,6 @@ -from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path from .main import evaluate_cuda_setup +from .paths import ( + CUDA_RUNTIME_LIB, + determine_cuda_runtime_lib_path, + extract_candidate_paths, +) diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index 6a6bc74..6feead2 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -17,11 +17,13 @@ evaluation: """ import ctypes + import torch -from .paths import determine_cuda_runtime_lib_path from bitsandbytes.cextension import CUDASetup +from .paths import determine_cuda_runtime_lib_path + def check_cuda_result(cuda, result_val): # 3. Check for CUDA errors @@ -48,7 +50,7 @@ def get_cuda_version(cuda, cudart_path): minor = (version-(major*1000))//10 if major < 11: - CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') + CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') return f'{major}{minor}' @@ -129,7 +131,7 @@ def evaluate_cuda_setup(): failure = True cuda_setup.add_log_entry("WARNING: No libcudart.so found! 
Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True) else: - cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}")) + cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") if cc == '' or cc is None: failure = True diff --git a/bitsandbytes/cuda_setup/paths.py b/bitsandbytes/cuda_setup/paths.py index 3a5e65d..1c100db 100644 --- a/bitsandbytes/cuda_setup/paths.py +++ b/bitsandbytes/cuda_setup/paths.py @@ -1,6 +1,7 @@ import errno from pathlib import Path from typing import Set, Union + from bitsandbytes.cextension import CUDASetup from .env_vars import get_potentially_lib_path_containing_env_vars diff --git a/bitsandbytes/debug_cli.py b/bitsandbytes/debug_cli.py deleted file mode 100644 index 4306bc0..0000000 --- a/bitsandbytes/debug_cli.py +++ /dev/null @@ -1,26 +0,0 @@ -import typer - -cli = typer.Typer() - - -@cli.callback() -def callback(): - """ - Awesome Portal Gun - """ - - -@cli.command() -def shoot(): - """ - Shoot the portal gun - """ - typer.echo("Shooting portal gun") - - -@cli.command() -def load(): - """ - Load the portal gun - """ - typer.echo("Loading portal gun") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 662e806..95a7c4f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,17 +3,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct +import itertools import operator import random import torch import itertools import math +from functools import reduce # Required in Python 3 from typing import Tuple from torch import Tensor from .cextension import COMPILED_WITH_CUDA, lib -from functools import reduce # Required in Python 3 + # math.prod not compatible with python < 3.8 def prod(iterable): @@ -84,7 +86,7 @@ if COMPILED_WITH_CUDA: ) -class CUBLAS_Context(object): +class CUBLAS_Context: _instance = None def __init__(self): @@ -114,7 +116,7 @@ class CUBLAS_Context(object): return self.context[device.index] -class Cusparse_Context(object): +class Cusparse_Context: _instance = None def __init__(self): @@ -264,13 +266,12 @@ def create_quantile_map(A, total_bits=8): def get_special_format_str(): if not torch.cuda.is_available(): return 'col_turing' - major, minor = torch.cuda.get_device_capability() + major, _minor = torch.cuda.get_device_capability() if major <= 7: return "col_turing" - elif major == 8: + if major == 8: return "col_ampere" - else: - return "col_turing" + return "col_turing" @@ -397,8 +398,6 @@ def nvidia_transform( dim2 = ct.c_int32(shape[2]) ptr = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrOut = get_ptr(out) func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) return out, new_state @@ -1053,7 +1052,7 @@ def histogram_scatter_add_2d( maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) - is_on_gpu([histogram, index1, index2d, source]) + is_on_gpu([histogram, index1, index2, source]) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): @@ -1228,7 +1227,7 @@ def igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) # B^T @ A^T = C^T - # [km, nk -> mn] + # [km, nk -> mn] is_on_gpu([B, A, out]) lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), get_ptr(B), get_ptr(A), get_ptr(out), 
ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) @@ -1512,7 +1511,7 @@ def get_colrow_absmax( return row_stats, col_stats, nnz_block_ptr -class COOSparseTensor(object): +class COOSparseTensor: def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 assert colidx.dtype == torch.int32 @@ -1529,7 +1528,7 @@ class COOSparseTensor(object): self.values = values -class CSRSparseTensor(object): +class CSRSparseTensor: def __init__(self, rows, cols, nnz, rowptr, colidx, values): assert rowptr.dtype == torch.int32 assert colidx.dtype == torch.int32 @@ -1546,7 +1545,7 @@ class CSRSparseTensor(object): self.values = values -class CSCSparseTensor(object): +class CSCSparseTensor: def __init__(self, rows, cols, nnz, colptr, rowidx, values): assert colptr.dtype == torch.int32 assert rowidx.dtype == torch.int32 @@ -1710,8 +1709,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No dim1 = ct.c_int32(shape[0] * shape[1]) dim2 = ct.c_int32(shape[2]) - ptrA = get_ptr(A) - ptrOut = get_ptr(out) is_on_gpu([A, out]) if to_order == 'col32': if transpose: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 4f82cdc..7f4c670 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -2,24 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import ( - Any, - Callable, - Dict, - Iterator, - Mapping, - Optional, - Set, - Tuple, - TypeVar, - Union, - overload, -) +from typing import Optional, TypeVar, Union, overload import torch import torch.nn.functional as F from torch import Tensor, device, dtype, nn -from torch.nn.parameter import Parameter import bitsandbytes as bnb from bitsandbytes.optim import GlobalOptimManager @@ -39,7 +26,7 @@ class StableEmbedding(torch.nn.Embedding): sparse: bool = False, _weight: Optional[Tensor] = None, ) -> None: - super(StableEmbedding, self).__init__( + super().__init__( num_embeddings, embedding_dim, padding_idx, @@ -96,7 +83,7 @@ class Embedding(torch.nn.Embedding): sparse: bool = False, _weight: Optional[Tensor] = None, ) -> None: - super(Embedding, self).__init__( + super().__init__( num_embeddings, embedding_dim, padding_idx, @@ -225,7 +212,7 @@ class Linear8bitLt(nn.Linear): threshold=0.0, index=None, ): - super(Linear8bitLt, self).__init__( + super().__init__( input_features, output_features, bias ) self.state = bnb.MatmulLtState() @@ -267,7 +254,7 @@ class Linear8bitLt(nn.Linear): self.weight.data = self.state.CxB elif self.state.memory_efficient_backward and self.state.CxB is not None: # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass. - # Thus, we delete CxB from the state. + # Thus, we delete CxB from the state. 
del self.state.CxB return out diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index d18f1d1..8c8a8f4 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -5,12 +5,11 @@ from bitsandbytes.cextension import COMPILED_WITH_CUDA +from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adam import Adam, Adam8bit, Adam32bit from .adamw import AdamW, AdamW8bit, AdamW32bit -from .sgd import SGD, SGD8bit, SGD32bit -from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .lamb import LAMB, LAMB8bit, LAMB32bit -from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit - +from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .optimizer import GlobalOptimManager +from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit +from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 7e2f566..7d8df58 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -21,18 +21,18 @@ class Adagrad(Optimizer1State): block_wise=True, ): if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) + f"Invalid weight_decay value: {weight_decay}" ) if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: raise ValueError("Lr Decay != 0.0 not supported!") - super(Adagrad, self).__init__( + super().__init__( "adagrad", params, lr, @@ -63,19 +63,19 @@ class Adagrad8bit(Optimizer1State): block_wise=True, ): if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) + f"Invalid weight_decay value: {weight_decay}" ) if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: raise ValueError("Lr Decay != 0.0 not supported!") assert block_wise - super(Adagrad8bit, self).__init__( + super().__init__( "adagrad", params, lr, @@ -106,18 +106,18 @@ class Adagrad32bit(Optimizer1State): block_wise=True, ): if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) + f"Invalid weight_decay value: {weight_decay}" ) if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: raise ValueError("Lr Decay != 0.0 not supported!") - super(Adagrad32bit, self).__init__( + super().__init__( "adagrad", params, lr, diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 3634971..396aeb8 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -28,7 +28,7 @@ class Adam(Optimizer2State): percentile_clipping=100, 
block_wise=True, ): - super(Adam, self).__init__( + super().__init__( "adam", params, lr, @@ -57,7 +57,7 @@ class Adam8bit(Optimizer2State): percentile_clipping=100, block_wise=True, ): - super(Adam8bit, self).__init__( + super().__init__( "adam", params, lr, @@ -86,7 +86,7 @@ class Adam32bit(Optimizer2State): percentile_clipping=100, block_wise=True, ): - super(Adam32bit, self).__init__( + super().__init__( "adam", params, lr, @@ -146,7 +146,7 @@ class AnalysisAdam(torch.optim.Optimizer): weight_decay=weight_decay, amsgrad=amsgrad, ) - super(AnalysisAdam, self).__init__(params, defaults) + super().__init__(params, defaults) self.analysis = bnb_analysis self.savedir = savedir diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index d0b3bde..022e64c 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -20,7 +20,7 @@ class AdamW(Optimizer2State): percentile_clipping=100, block_wise=True, ): - super(AdamW, self).__init__( + super().__init__( "adam", params, lr, @@ -49,7 +49,7 @@ class AdamW8bit(Optimizer2State): percentile_clipping=100, block_wise=True, ): - super(AdamW8bit, self).__init__( + super().__init__( "adam", params, lr, @@ -78,7 +78,7 @@ class AdamW32bit(Optimizer2State): percentile_clipping=100, block_wise=True, ): - super(AdamW32bit, self).__init__( + super().__init__( "adam", params, lr, diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 8f365f7..1fbb6fa 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -23,7 +23,7 @@ class LAMB(Optimizer2State): block_wise=False, max_unorm=1.0, ): - super(LAMB, self).__init__( + super().__init__( "lamb", params, lr, @@ -56,7 +56,7 @@ class LAMB8bit(Optimizer2State): block_wise=False, max_unorm=1.0, ): - super(LAMB8bit, self).__init__( + super().__init__( "lamb", params, lr, @@ -89,7 +89,7 @@ class LAMB32bit(Optimizer2State): block_wise=False, max_unorm=1.0, ): - super(LAMB32bit, self).__init__( + super().__init__( "lamb", params, lr, diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 8a89fb0..73554e3 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -25,9 +25,9 @@ class LARS(Optimizer1State): ): if momentum == 0: raise NotImplementedError( - f"LARS without momentum is not supported!" + "LARS without momentum is not supported!" ) - super(LARS, self).__init__( + super().__init__( "lars", params, lr, @@ -59,9 +59,9 @@ class LARS8bit(Optimizer1State): ): if momentum == 0: raise NotImplementedError( - f"LARS without momentum is not supported!" + "LARS without momentum is not supported!" ) - super(LARS8bit, self).__init__( + super().__init__( "lars", params, lr, @@ -93,9 +93,9 @@ class LARS32bit(Optimizer1State): ): if momentum == 0: raise NotImplementedError( - f"LARS without momentum is not supported!" + "LARS without momentum is not supported!" 
) - super(LARS32bit, self).__init__( + super().__init__( "lars", params, lr, @@ -123,12 +123,12 @@ class PytorchLARS(Optimizer): max_unorm=0.02, ): if lr < 0.0: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if momentum < 0.0: - raise ValueError("Invalid momentum value: {}".format(momentum)) + raise ValueError(f"Invalid momentum value: {momentum}") if weight_decay < 0.0: raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) + f"Invalid weight_decay value: {weight_decay}" ) defaults = dict( @@ -143,10 +143,10 @@ class PytorchLARS(Optimizer): raise ValueError( "Nesterov momentum requires a momentum and zero dampening" ) - super(PytorchLARS, self).__init__(params, defaults) + super().__init__(params, defaults) def __setstate__(self, state): - super(PytorchLARS, self).__setstate__(state) + super().__setstate__(state) for group in self.param_groups: group.setdefault("nesterov", False) @@ -181,7 +181,7 @@ class PytorchLARS(Optimizer): state = self.state[p] d_p = p.grad if weight_decay != 0: - d_p = d_p.add(param, alpha=weight_decay) + d_p = d_p.add(p, alpha=weight_decay) if momentum != 0: buf = state.get("momentum_buffer", None) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 4fb30cd..867ad3d 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -12,13 +12,13 @@ import torch import bitsandbytes.functional as F -class MockArgs(object): +class MockArgs: def __init__(self, initial_data): for key in initial_data: setattr(self, key, initial_data[key]) -class GlobalOptimManager(object): +class GlobalOptimManager: _instance = None def __init__(self): @@ -56,9 +56,9 @@ class GlobalOptimManager(object): """ Overrides initial optimizer config for specific parameters. - The key-values of the optimizer config for the input parameters are overidden + The key-values of the optimizer config for the input parameters are overridden This can be both, optimizer parameters like "betas", or "lr" or it can be - 8-bit specific paramters like "optim_bits", "percentile_clipping". + 8-bit specific parameters like "optim_bits", "percentile_clipping". Parameters ---------- @@ -93,13 +93,12 @@ class GlobalOptimManager(object): class Optimizer8bit(torch.optim.Optimizer): def __init__(self, params, defaults, optim_bits=32): - super(Optimizer8bit, self).__init__(params, defaults) + super().__init__(params, defaults) self.initialized = False self.name2qmap = {} self.mng = GlobalOptimManager.get_instance() - self.non_castable_tensor_keys = set( - [ + self.non_castable_tensor_keys = { "qmap1", "qmap2", "max1", @@ -112,8 +111,7 @@ class Optimizer8bit(torch.optim.Optimizer): "absmax1", "absmax2", "unorm_vec", - ] - ) + } if optim_bits == 8: self.fill_qmap() @@ -123,7 +121,7 @@ class Optimizer8bit(torch.optim.Optimizer): self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False) def __setstate__(self, state): - super(Optimizer8bit, self).__setstate__(state) + super().__setstate__(state) def load_state_dict(self, state_dict): r"""Loads the optimizer state. 
@@ -155,8 +153,8 @@ class Optimizer8bit(torch.optim.Optimizer): id_map = { old_id: p for old_id, p in zip( - chain.from_iterable((g["params"] for g in saved_groups)), - chain.from_iterable((g["params"] for g in groups)), + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), ) } @@ -284,11 +282,11 @@ class Optimizer8bit(torch.optim.Optimizer): return config def init_state(self, group, p, gindex, pindex): - raise NotImplementedError(f"init_state method needs to be overidden") + raise NotImplementedError("init_state method needs to be overridden") def update_step(self, group, p, gindex, pindex): raise NotImplementedError( - f"The update_step method needs to be overidden" + "The update_step method needs to be overridden" ) @@ -310,9 +308,9 @@ class Optimizer2State(Optimizer8bit): skip_zeros=False, ): if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if isinstance(betas, str): # format: '(beta1, beta2)' betas = betas.replace("(", "").replace(")", "").strip().split(",") @@ -324,10 +322,10 @@ class Optimizer2State(Optimizer8bit): ) if not 0.0 <= weight_decay: raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) + f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super(Optimizer2State, self).__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits) if args is None: args = {} @@ -542,9 +540,9 @@ class Optimizer1State(Optimizer8bit): skip_zeros=False, ): if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: raise ValueError( @@ -552,10 +550,10 @@ class Optimizer1State(Optimizer8bit): ) if not 0.0 <= weight_decay: raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) + f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super(Optimizer1State, self).__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits) if args is None: args = {} diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 7ddb12c..2853ca7 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -23,11 +23,11 @@ class RMSprop(Optimizer1State): ): if alpha == 0: raise NotImplementedError( - f"RMSprop with alpha==0.0 is not supported!" + "RMSprop with alpha==0.0 is not supported!" ) if centered: - raise NotImplementedError(f"Centered RMSprop is not supported!") - super(RMSprop, self).__init__( + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( "rmsprop", params, lr, @@ -59,11 +59,11 @@ class RMSprop8bit(Optimizer1State): ): if alpha == 0: raise NotImplementedError( - f"RMSprop with alpha==0.0 is not supported!" + "RMSprop with alpha==0.0 is not supported!" 
) if centered: - raise NotImplementedError(f"Centered RMSprop is not supported!") - super(RMSprop8bit, self).__init__( + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( "rmsprop", params, lr, @@ -96,11 +96,11 @@ class RMSprop32bit(Optimizer1State): if alpha == 0: raise NotImplementedError( - f"RMSprop with alpha==0.0 is not supported!" + "RMSprop with alpha==0.0 is not supported!" ) if centered: - raise NotImplementedError(f"Centered RMSprop is not supported!") - super(RMSprop32bit, self).__init__( + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( "rmsprop", params, lr, diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index f7b8934..3c0fc2b 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -21,8 +21,8 @@ class SGD(Optimizer1State): block_wise=True, ): if momentum == 0: - raise NotImplementedError(f"SGD without momentum is not supported!") - super(SGD, self).__init__( + raise NotImplementedError("SGD without momentum is not supported!") + super().__init__( "momentum", params, lr, @@ -52,8 +52,8 @@ class SGD8bit(Optimizer1State): block_wise=True, ): if momentum == 0: - raise NotImplementedError(f"SGD without momentum is not supported!") - super(SGD8bit, self).__init__( + raise NotImplementedError("SGD without momentum is not supported!") + super().__init__( "momentum", params, lr, @@ -83,8 +83,8 @@ class SGD32bit(Optimizer1State): block_wise=True, ): if momentum == 0: - raise NotImplementedError(f"SGD without momentum is not supported!") - super(SGD32bit, self).__init__( + raise NotImplementedError("SGD without momentum is not supported!") + super().__init__( "momentum", params, lr, diff --git a/compile_from_source.md b/compile_from_source.md index 71b0c09..2c4a6ad 100644 --- a/compile_from_source.md +++ b/compile_from_source.md @@ -4,7 +4,7 @@ Basic steps. 1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cpuonly` 2. `CUDA_VERSION=XXX python setup.py install` -To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive). +To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive). For your convenience, there is an installation script in the root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. To do this you can add this to your `.bashrc` by executing these commands: ```bash @@ -13,7 +13,7 @@ echo "export PATH=$PATH:/usr/local/cuda/bin/" >> ~/.bashrc source ~/.bashrc ``` -By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. 
If this path is not set it is inferred from the path of your `nvcc` compiler. +By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler. Either `nvcc` needs to be in path for the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 2081e68..e28e7b2 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -62,7 +62,7 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long for (int i = 0; i < valid_chunks; i++) int err = pthread_join(threads[i], NULL); - + free(threads); for (int i = 0; i < valid_chunks; i++) free(args[i]); diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 29f266a..08b9b44 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1,6 +1,6 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. #include @@ -303,7 +303,7 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou if(threadIdx.x % 32 < 8) { // offset: 8 values per 256 input values - // + // int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8; } @@ -574,7 +574,7 @@ __global__ void kDequantize(float *code, unsigned char *A, float *out, const int template __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) -__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n) @@ -622,7 +622,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, { switch(OPTIMIZER) { - case ADAM: + case ADAM: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); s1_vals[j] *= correction1; @@ -653,7 +653,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, template __launch_bounds__(TH, 1) -__global__ void kOptimizer32bit2State(T* g, T* p, +__global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) @@ -716,7 +716,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p, { switch(OPTIMIZER) { - case ADAM: + case ADAM: if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); @@ -741,7 +741,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p, template __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) -__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n) @@ -783,19 +783,19 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, { switch(OPTIMIZER) { 
- case MOMENTUM: + case MOMENTUM: if(step == 1) s1_vals[j] = (float)g_vals[j]; // state update else s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; - case RMSPROP: + case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; - case ADAGRAD: + case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm @@ -819,7 +819,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, template __launch_bounds__(TH, 1) -__global__ void kOptimizer32bit1State(T *g, T *p, +__global__ void kOptimizer32bit1State(T *g, T *p, float *state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) @@ -882,7 +882,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p, { switch(OPTIMIZER) { - case MOMENTUM: + case MOMENTUM: if(step == 1) s1_vals[j] = (float)g_vals[j]; else @@ -890,11 +890,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p, p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); break; - case RMSPROP: + case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); break; - case ADAGRAD: + case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); break; @@ -1156,12 +1156,12 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha template __global__ void __launch_bounds__(NUM_THREADS, 2) -kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, float *unorm, - const float beta1, + const float beta1, const float eps, const int step, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n) { @@ -1211,7 +1211,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; switch(OPTIMIZER) { - case MOMENTUM: + case MOMENTUM: if(step == 1) s1_vals[j] = (float)g_vals[j]; else @@ -1219,7 +1219,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c if(unorm != NULL) local_unorm += s1_vals[j]*s1_vals[j]; break; - case RMSPROP: + case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; } @@ -1244,10 +1244,10 @@ template __global__ void kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, - const float beta1, + const float beta1, const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, + float* __restrict__ const quantiles1, + float* 
max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n) { @@ -1313,7 +1313,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, switch(OPTIMIZER) { - case MOMENTUM: + case MOMENTUM: if(step == 1) s1_vals[j] = g_vals[j]; else @@ -1321,7 +1321,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); break; - case RMSPROP: + case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); break; @@ -1401,7 +1401,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n) { @@ -1545,7 +1545,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char StoreT(temp_storage.storeh).Store(&(p[i]), g_vals, valid_items); // quantizaztion: 2.67/1.70 -> 3.4/3.3 - # pragma unroll N_PER_TH + # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); @@ -1658,16 +1658,16 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char switch(OPTIMIZER) { - case MOMENTUM: + case MOMENTUM: if(step == 1) s1_vals[j] = g_val; else s1_vals[j] = (s1_vals[j]*beta1) + g_val; break; - case RMSPROP: + case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; - case ADAGRAD: + case ADAGRAD: s1_vals[j] = s1_vals[j] + (g_val*g_val); break; } @@ -1698,14 +1698,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char { switch(OPTIMIZER) { - case MOMENTUM: + case MOMENTUM: p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); break; - case RMSPROP: + case RMSPROP: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); break; - case ADAGRAD: + case ADAGRAD: g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); break; @@ -1718,7 +1718,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); // quantizaztion: 2.67/1.70 -> 3.4/3.3 - # pragma unroll N_PER_TH + # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); @@ -1895,9 +1895,9 @@ template __global__ void kd { // Strategy: To dequantize we need to load col/row statistics. This can be very expensive - // since different row/col stats need to be loaded with each thread. + // since different row/col stats need to be loaded with each thread. // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure - // and would lead to low global load utilization. + // and would lead to low global load utilization. // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. 
// (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. @@ -1905,7 +1905,7 @@ template __global__ void kd // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the - // shared memory loads. + // shared memory loads. // data is in 32 column-tile major with tile width 32 columns and numRows rows // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. @@ -2142,7 +2142,7 @@ template we can remove the divisions to speed up compute since outRows is always a multiple of 8 // 256*outRows/8*base_row/32 = outRows*base_row int col_offset = outRows*base_row; @@ -2349,7 +2349,7 @@ template we can remove the divisions to speed up compute since outRows is always a multiple of 8 // 1024*outRows/32*base_row/32 = outRows*base_row int col_offset = outRows*base_row; @@ -2447,7 +2447,7 @@ template +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { @@ -2577,7 +2577,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o #pragma unroll num_items for(int k = 0; k < num_items; k++) local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; - + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; } else @@ -2591,11 +2591,11 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o idx_col_B += blockDim.x*SPMM_ITEMS; local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; - } + } } template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA) -{ +{ int local_colidx = idx[blockIdx.x]; if(FORMAT==COL_TURING) @@ -2655,7 +2655,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * out[out_idx] = val; } } -} +} //============================================================== // TEMPLATE DEFINITIONS diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index cca983b..d90ea13 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -1,6 +1,6 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. 
#include @@ -18,49 +18,49 @@ template __global__ template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n); template -__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n); template -__global__ void kOptimizer32bit2State(T* g, T* p, +__global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template -__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float beta1, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n); template -__global__ void kOptimizer32bit1State(T* g, T* p, +__global__ void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void -kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, float *unorm, - const float beta1, - const float eps, const int step, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, + const float beta1, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, const float weight_decay, const float gnorm_scale, const int n); template __global__ void -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, - const float beta1, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, + const float beta1, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, float weight_decay, const float gnorm_scale, const int n); @@ -70,7 +70,7 @@ __global__ void kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, float *unorm, const float beta1, const float beta2, - const float eps, const int step, + const float eps, const int step, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, const float gnorm_scale, const int n); @@ -81,7 +81,7 @@ __global__ void kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, const float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, - const float eps, const int step, const float lr, + const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const 
quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n); @@ -121,5 +121,3 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); #endif - - diff --git a/csrc/ops.cu b/csrc/ops.cu index 30079e6..e770e10 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -1,6 +1,6 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. #include @@ -241,7 +241,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in } -void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount) { const int falpha = 1; @@ -351,7 +351,7 @@ template void trans cublasLtOrder_t orderOut = get_order(); int ldA = get_leading_dim(dim1, dim2); int ldOut = get_leading_dim(dim1, dim2); - + cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; cublasLtMatrixTransformDesc_t A2Out_desc = NULL; cublasOperation_t opTranspose = CUBLAS_OP_T; @@ -397,7 +397,7 @@ template void transform(cublasLtHandle_t ltHa template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); -template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { #ifdef NO_CUBLASLT cout << "" << endl; diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 66e3843..31d4dd8 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -1,6 +1,6 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. 
@@ -131,7 +131,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n); template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); -template void optimizer32bit(T* g, T* p, +template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, float beta1, float beta2, float eps, float weight_decay, int step, float lr, const float gnorm_scale, bool skip_zeros, int n); @@ -139,15 +139,15 @@ template void optimizer32bit(T* g, T* p, template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, float *unorm, float max_unorm, float param_norm, float beta1, float beta2, - float eps, int step, float lr, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n); template void optimizerStatic8bitBlockwise(T* p, T* g, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); @@ -155,7 +155,7 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 5bac30e..d8b2290 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -1,6 +1,6 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. #if BUILD_CUDA @@ -9,7 +9,7 @@ #include // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. -// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to +// We use macro functions to expand all the different optimizers. 
Looks ugly, and is ugly, but its better than to // maintain all that boilerplate //=================================================================================== // UNMANGLED CALLS @@ -290,4 +290,3 @@ extern "C" void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } } - diff --git a/cuda_install.sh b/cuda_install.sh index b6c553b..27c164a 100644 --- a/cuda_install.sh +++ b/cuda_install.sh @@ -76,6 +76,3 @@ if [[ -n "$CUDA_VERSION" ]]; then else echo "" fi - - - diff --git a/howto_config_override.md b/howto_config_override.md index 4680776..55b24e3 100644 --- a/howto_config_override.md +++ b/howto_config_override.md @@ -14,16 +14,16 @@ mng.register_parameters(model.parameters()) # 1. register parameters while still model = model.cuda() # use 8-bit optimizer states for all parameters -adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) +adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) # 2a. override: the parameter model.fc1.weight now uses 32-bit Adam -mng.override_config(model.fc1.weight, 'optim_bits', 32) +mng.override_config(model.fc1.weight, 'optim_bits', 32) # 2b. override: the two special layers use # sparse optimization + different learning rate + different Adam betas mng.override_config([model.special.weight, model.also_special.weight], - key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)}) -``` + key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)}) +``` Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm` For overrides for particular layers we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager: diff --git a/include/Algo-Direct-Common.h b/include/Algo-Direct-Common.h index cf5f0c9..c970849 100644 --- a/include/Algo-Direct-Common.h +++ b/include/Algo-Direct-Common.h @@ -121,7 +121,7 @@ template struct DirectTraits { typedef FVec1 fVec1; - + static void checkH(T scaler, T H_Times_x0, T xN) { union { @@ -177,9 +177,9 @@ struct DirectInfo , cst0(fun_t::cst0(H, x[0])) { myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned"); - + uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]); - + const uint32 npad = Gap-1; const uint32 n_sz = n + npad; // size of padded vector @@ -320,7 +320,7 @@ struct DirectInfo T cst0 = fun_t::cst0(H, px[0]); const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]); buckets.resize(maxIndex + 1); - + data = Data(px, n, H, buckets.begin(), (npad? 
xi.begin(): NULL)); } diff --git a/include/SIMD.h b/include/SIMD.h index 642b80a..a2ac1a9 100644 --- a/include/SIMD.h +++ b/include/SIMD.h @@ -203,7 +203,7 @@ struct IVec : IVecBase #if 1 // takes 4 cycles __m128i hi = _mm_shuffle_epi32(vec, 2); // 1 cycle - __m128i s = _mm_add_epi32(vec, hi); + __m128i s = _mm_add_epi32(vec, hi); int32 x = _mm_cvtsi128_si32(s); return -x; #else diff --git a/setup.py b/setup.py index b800d38..3077c74 100644 --- a/setup.py +++ b/setup.py @@ -26,9 +26,6 @@ setup( keywords="gpu optimizers optimization 8-bit quantization compression", url="https://github.com/TimDettmers/bitsandbytes", packages=find_packages(), - entry_points={ - "console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"], - }, package_data={"": libs}, long_description=read("README.md"), long_description_content_type="text/markdown", diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 40bb441..c67126d 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,4 +1,4 @@ -from itertools import product, permutations +from itertools import permutations, product import pytest import torch @@ -27,7 +27,7 @@ str_values = list( ) ) names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format( + "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format( *vals ) for vals in str_values @@ -286,7 +286,7 @@ str_values = list( has_bias ) ) -names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values] +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values] @pytest.mark.parametrize( @@ -336,7 +336,7 @@ def test_matmullt( ) bias = None bias2 = None - if has_bias: + if has_bias: bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index bf9a003..7edc01f 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,13 +1,13 @@ import os -import pytest -import bitsandbytes as bnb - from typing import List, NamedTuple +import pytest + +import bitsandbytes as bnb from bitsandbytes.cuda_setup import ( CUDA_RUNTIME_LIB, - evaluate_cuda_setup, determine_cuda_runtime_lib_path, + evaluate_cuda_setup, extract_candidate_paths, ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 6a65e2d..69c200a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -28,7 +28,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0): class FFN(torch.nn.Module): def __init__(self, input_features, hidden_size, bias=True): - super(FFN, self).__init__() + super().__init__() self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias) self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias) @@ -42,7 +42,7 @@ class FFN(torch.nn.Module): return x -class Timer(object): +class Timer: def __init__(self): self.starts = {} self.ends = {} @@ -69,7 +69,7 @@ class Timer(object): self.ends.pop(name) if print_ms and name in self.agg: - print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0)) + print(f"{name} took: {self.agg[name] / 1000.0:.5f}s") return self.agg[name] @@ -302,7 +302,7 @@ batched = [False, True] values = list(product(dim1, dim2, 
methods, batched)) values_names = list(product(dim1, dim2, method_names, batched)) names = [ - "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) + "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals) for vals in values_names ] @@ -360,7 +360,7 @@ seq_dim = torch.randint(16, 256, size=(n,)).tolist() transpose = [(False, False), (False, True), (True, False), (True, True)] values = list(product(hidden_dim, batch_dim, transpose, seq_dim)) names = [ - "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals) + "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals) for vals in values ] @@ -425,7 +425,7 @@ hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() batch_dim = torch.randint(2, 16, size=(n,)).tolist() values = list(product(seq_dim, hidden_dim, batch_dim)) names = [ - "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values + "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values ] @@ -457,7 +457,7 @@ batch_dim = torch.randint(2, 16, size=(n,)).tolist() transpose = [False, True] values = list(product(seq_dim, hidden_dim, batch_dim, transpose)) names = [ - "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals) + "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals) for vals in values ] @@ -542,7 +542,7 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist() transpose = [(False, False), (True, False), (False, True), (True, True)] values = list(product(dim1, dim2, dim3, dim4, transpose)) names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) + "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals) for vals in values ] @@ -580,7 +580,7 @@ dim1 = torch.randint(1, 64, size=(n,)).tolist() dim2 = torch.randint(32, 128, size=(n,)).tolist() dim3 = torch.randint(32, 256, size=(n,)).tolist() values = list(product(dim1, dim2, dim3)) -names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values] +names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) @@ -609,7 +609,7 @@ transpose = [False] dims = [2, 3] values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) -names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values] +names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values] @pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names) @@ -691,7 +691,7 @@ ldb = [0] # ldb = list(range(256, 1*1024, 256)) values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals) + "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals) for vals in values ] @@ -739,7 +739,7 @@ dims = (2,) # ldb = list(range(256, 1*1024, 256)) values = list(product(dim1, dim2, dim3, dim4, dims)) names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) + "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals) for vals in values ] @@ -797,7 +797,7 @@ values = [ # values = list(product(batch, seq, model, hidden)) names = [ - "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values + "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values ] @@ -965,7 +965,7 @@ dims = (2,) formatB = ["col_turing", "col_ampere"] has_bias = [True, False] values = list(product(dim1, dim4, dims, formatB, has_bias)) -names = 
["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values] +names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) @@ -1015,7 +1015,7 @@ dim2 = [1 * 1024] dims = (2,) # ldb = list(range(256, 1*1024, 256)) values = list(product(dim1, dim2, dims)) -names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values] +names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) @@ -1071,7 +1071,7 @@ dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() values = list(product(dim1, dim2)) -names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] +names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2", values, ids=names) @@ -1118,7 +1118,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() values = list(zip(dim1, dim4, inner)) -names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] +names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @@ -1162,7 +1162,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() values = list(zip(dim1, dim4, inner)) -names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] +names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @@ -1237,7 +1237,7 @@ inner = [12288 * 4, 4096 * 4] dim4 = [12288, 4096] values = list(zip(dim1, dim4, inner)) -names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] +names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @@ -1303,7 +1303,7 @@ values = list( product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose) ) names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format( + "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format( *vals ) for vals in values @@ -1354,7 +1354,7 @@ a_order = ["col_turing"] out_order = ["row"] values = list(product(dim1, dim2, dtype, a_order, out_order)) names = [ - "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals) + "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals) for vals in values ] @@ -1380,7 +1380,7 @@ dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() # dim2 = [5] values = list(product(dim1, dim2)) -names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] +names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2", values, ids=names) @@ -1417,7 +1417,7 @@ dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist() # dim2 = [11] transposed_B = [False, True] values = list(product(dim1, dim2, transposed_B)) -names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values] +names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) @@ -1498,7 +1498,7 @@ n = 2 dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist() values = list(product(dim1, dim2)) -names = 
["dim1_{0}_dim2_{1}".format(*vals) for vals in values] +names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2", values, ids=names) @@ -1563,7 +1563,7 @@ dtype = [torch.float16] out_function = ["zeros", "ones"] values = list(product(dim1, dim2, dtype, out_function)) names = [ - "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values + "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values ] @@ -1680,7 +1680,7 @@ dim2 = [2048] # dim2 = [2] dtype = [torch.int8] values = list(product(dim1, dim2, dtype)) -names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values] +names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) @@ -1796,7 +1796,7 @@ values.append((batch_size, seqdim, 768, 4 * 768)) # values.append((batch_size, seqdim, 5140, 4*5140)) #values.append((batch_size, seqdim, 12288, 4*12288)) names = [ - "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values + "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values ] diff --git a/tests/test_modules.py b/tests/test_modules.py index ccbf670..ffcf304 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -7,7 +7,7 @@ from torch import nn import bitsandbytes as bnb -class MockArgs(object): +class MockArgs: def __init__(self, initial_data): for key in initial_data: setattr(self, key, initial_data[key]) @@ -15,7 +15,7 @@ class MockArgs(object): class MLP8bit(torch.nn.Module): def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): - super(MLP8bit, self).__init__() + super().__init__() self.fc1 = bnb.nn.Linear8bitLt( dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, threshold=threshold @@ -289,7 +289,7 @@ class LinearFunction(torch.autograd.Function): class Linear8bit(nn.Module): def __init__(self, input_features, output_features, bias=True, args=None): - super(Linear8bit, self).__init__() + super().__init__() self.input_features = input_features self.output_features = output_features self.args = args @@ -312,7 +312,7 @@ class Linear8bit(nn.Module): threshold = [0.0, 3.0] values = threshold -names = ["threshold_{0}".format(vals) for vals in values] +names = [f"threshold_{vals}" for vals in values] @pytest.mark.parametrize("threshold", values, ids=names) @@ -378,7 +378,7 @@ def test_linear8bitlt_accumulated_gradient(): threshold = [0.0, 2.0] values = threshold -names = ["threshold_{0}".format(vals) for vals in values] +names = [f"threshold_{vals}" for vals in values] @pytest.mark.parametrize("threshold", values, ids=names) diff --git a/tests/test_optim.py b/tests/test_optim.py index 80b0802..3df2dad 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -18,7 +18,7 @@ k = 20 def get_temp_dir(): - path = "/tmp/autoswap/{0}".format(str(uuid.uuid4())) + path = f"/tmp/autoswap/{str(uuid.uuid4())}" os.makedirs(path, exist_ok=True) return path @@ -116,7 +116,7 @@ gtype = [torch.float32, torch.float16] optimizer_names = ["adam", "momentum", "rmsprop", "lars"] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ - "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values + "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values ] @@ -187,7 +187,7 @@ dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] values = list(product(dim1, dim2, gtype)) -names = 
["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values] +names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) @@ -250,7 +250,7 @@ optimizer_names = [ ] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ - "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values + "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values ] @@ -391,7 +391,7 @@ gtype = [torch.float32] optim_bits = [32, 8] values = list(product(dim1, dim2, gtype, optim_bits)) names = [ - "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) + "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals) for vals in values ] @@ -495,7 +495,7 @@ gtype = [torch.float32, torch.float16] optimizer_names = ["adam8bit_blockwise"] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ - "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values + "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values ] diff --git a/to_be_fixed__complaints_by_linter.log b/to_be_fixed__complaints_by_linter.log deleted file mode 100644 index d696729..0000000 --- a/to_be_fixed__complaints_by_linter.log +++ /dev/null @@ -1,149 +0,0 @@ -./setup.py:20:10: F541 f-string is missing placeholders -./setup.py:21:13: F541 f-string is missing placeholders -./quicktest.py:5:1: F401 'bitsandbytes as bnb' imported but unused -./bitsandbytes/cuda_setup.py:42:56: F821 undefined name 'error_str' -./bitsandbytes/cuda_setup.py:43:15: F541 f-string is missing placeholders -./bitsandbytes/cuda_setup.py:67:5: F841 local variable 'context' is assigned to but never used -./bitsandbytes/cuda_setup.py:68:5: F841 local variable 'error_str' is assigned to but never used -./bitsandbytes/cuda_setup.py:76:9: F841 local variable 'result' is assigned to but never used -./bitsandbytes/cuda_setup.py:144:13: F841 local variable 'has_gpu' is assigned to but never used -./bitsandbytes/functional.py:294:13: F821 undefined name 'math' -./bitsandbytes/functional.py:295:16: F821 undefined name 'math' -./bitsandbytes/functional.py:303:5: F841 local variable 'ptrA' is assigned to but never used -./bitsandbytes/functional.py:304:5: F841 local variable 'ptrOut' is assigned to but never used -./bitsandbytes/functional.py:1057:17: W503 line break before binary operator -./bitsandbytes/functional.py:1058:17: W503 line break before binary operator -./bitsandbytes/functional.py:1059:17: W503 line break before binary operator -./bitsandbytes/functional.py:1649:1: F811 redefinition of unused 'get_special_format_str' from line 160 -./bitsandbytes/functional.py:1687:5: F841 local variable 'ptrA' is assigned to but never used -./bitsandbytes/functional.py:1688:5: F841 local variable 'ptrOut' is assigned to but never used -./bitsandbytes/functional.py:1802:5: F841 local variable 'ccolsA' is assigned to but never used -./bitsandbytes/functional.py:1805:5: F841 local variable 'cldb' is assigned to but never used -./bitsandbytes/functional.py:1806:5: F841 local variable 'cldc' is assigned to but never used -./bitsandbytes/functional.py:1873:9: F841 local variable 'dtype' is assigned to but never used -./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.MatmulLtState' imported but unused -./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.bmm_cublas' imported but unused -./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul' imported but unused -./bitsandbytes/__init__.py:6:1: F401 
'.autograd._functions.matmul_cublas' imported but unused -./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.mm_cublas' imported but unused -./bitsandbytes/__init__.py:9:1: F401 '.nn.modules' imported but unused -./bitsandbytes/__init__.py:12:5: F401 '.optim.adam' imported but unused -./bitsandbytes/autograd/_functions.py:5:1: F401 'bitsandbytes as bnb' imported but unused -./bitsandbytes/autograd/_functions.py:12:75: W291 trailing whitespace -./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Int8Params' imported but unused -./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bit' imported but unused -./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bitLt' imported but unused -./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.StableEmbedding' imported but unused -./bitsandbytes/nn/modules.py:5:1: F401 'typing.Any' imported but unused -./bitsandbytes/nn/modules.py:5:1: F401 'typing.Callable' imported but unused -./bitsandbytes/nn/modules.py:5:1: F401 'typing.Dict' imported but unused -./bitsandbytes/nn/modules.py:5:1: F401 'typing.Iterator' imported but unused -./bitsandbytes/nn/modules.py:5:1: F401 'typing.Mapping' imported but unused -./bitsandbytes/nn/modules.py:5:1: F401 'typing.Set' imported but unused -./bitsandbytes/nn/modules.py:5:1: F401 'typing.Tuple' imported but unused -./bitsandbytes/nn/modules.py:11:1: F401 'torch.nn.parameter.Parameter' imported but unused -./bitsandbytes/nn/modules.py:183:13: W503 line break before binary operator -./bitsandbytes/nn/modules.py:184:13: W503 line break before binary operator -./bitsandbytes/nn/modules.py:272:24: F821 undefined name 'dist' -./bitsandbytes/nn/modules.py:272:49: F821 undefined name 'dist' -./bitsandbytes/optim/optimizer.py:243:9: F841 local variable 'overflows' is assigned to but never used -./bitsandbytes/optim/optimizer.py:280:35: F541 f-string is missing placeholders -./bitsandbytes/optim/optimizer.py:283:35: F541 f-string is missing placeholders -./bitsandbytes/optim/lars.py:27:39: F541 f-string is missing placeholders -./bitsandbytes/optim/lars.py:59:39: F541 f-string is missing placeholders -./bitsandbytes/optim/lars.py:91:39: F541 f-string is missing placeholders -./bitsandbytes/optim/lars.py:157:13: F841 local variable 'params_with_grad' is assigned to but never used -./bitsandbytes/optim/lars.py:158:13: F841 local variable 'd_p_list' is assigned to but never used -./bitsandbytes/optim/lars.py:159:13: F841 local variable 'momentum_buffer_list' is assigned to but never used -./bitsandbytes/optim/lars.py:174:35: F821 undefined name 'param' -./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam' imported but unused -./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam8bit' imported but unused -./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam32bit' imported but unused -./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW' imported but unused -./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW8bit' imported but unused -./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW32bit' imported but unused -./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD' imported but unused -./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD8bit' imported but unused -./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD32bit' imported but unused -./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS' imported but unused -./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS8bit' imported but unused -./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS32bit' imported but unused 
-./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.PytorchLARS' imported but unused -./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB' imported but unused -./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB8bit' imported but unused -./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB32bit' imported but unused -./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop' imported but unused -./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop8bit' imported but unused -./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop32bit' imported but unused -./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad' imported but unused -./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad8bit' imported but unused -./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad32bit' imported but unused -./bitsandbytes/optim/__init__.py:17:1: F401 '.optimizer.GlobalOptimManager' imported but unused -./bitsandbytes/optim/adam.py:229:21: F841 local variable 'max_exp_avg_sq' is assigned to but never used -./bitsandbytes/optim/rmsprop.py:25:39: F541 f-string is missing placeholders -./bitsandbytes/optim/rmsprop.py:27:39: F541 f-string is missing placeholders -./bitsandbytes/optim/rmsprop.py:59:39: F541 f-string is missing placeholders -./bitsandbytes/optim/rmsprop.py:61:39: F541 f-string is missing placeholders -./bitsandbytes/optim/rmsprop.py:94:39: F541 f-string is missing placeholders -./bitsandbytes/optim/rmsprop.py:96:39: F541 f-string is missing placeholders -./bitsandbytes/optim/sgd.py:24:39: F541 f-string is missing placeholders -./bitsandbytes/optim/sgd.py:55:39: F541 f-string is missing placeholders -./bitsandbytes/optim/sgd.py:86:39: F541 f-string is missing placeholders -./tests/test_optim.py:1:1: F401 'ctypes' imported but unused -./tests/test_optim.py:199:5: F841 local variable 'mask' is assigned to but never used -./tests/test_optim.py:218:9: F841 local variable 'atol' is assigned to but never used -./tests/test_optim.py:218:15: F841 local variable 'rtol' is assigned to but never used -./tests/test_optim.py:304:17: W503 line break before binary operator -./tests/test_optim.py:354:21: W503 line break before binary operator -./tests/test_autograd.py:309:13: F841 local variable 'err' is assigned to but never used -./tests/test_cuda_setup_evaluator.py:31:9: F821 undefined name 'test_dir' -./tests/test_cuda_setup_evaluator.py:33:14: F821 undefined name 'test_input' -./tests/test_cuda_setup_evaluator.py:81:32: E203 whitespace before ':' -./tests/test_functional.py:55:13: F841 local variable 'ms' is assigned to but never used -./tests/test_functional.py:177:5: F841 local variable 'diffs' is assigned to but never used -./tests/test_functional.py:178:5: F841 local variable 'reldiffs' is assigned to but never used -./tests/test_functional.py:260:5: F841 local variable 'minA' is assigned to but never used -./tests/test_functional.py:261:5: F841 local variable 'maxA' is assigned to but never used -./tests/test_functional.py:584:5: F841 local variable 'func' is assigned to but never used -./tests/test_functional.py:617:17: F841 local variable 'offset' is assigned to but never used -./tests/test_functional.py:618:17: F841 local variable 'col2' is assigned to but never used -./tests/test_functional.py:619:17: F841 local variable 'row2' is assigned to but never used -./tests/test_functional.py:705:9: F841 local variable 'C1' is assigned to but never used -./tests/test_functional.py:706:9: F841 local variable 'C2' is assigned to but never used 
-./tests/test_functional.py:715:9: F841 local variable 'output' is assigned to but never used -./tests/test_functional.py:750:5: F841 local variable 'formatB' is assigned to but never used -./tests/test_functional.py:754:5: F841 local variable 'w2' is assigned to but never used -./tests/test_functional.py:763:5: F841 local variable 'dtype' is assigned to but never used -./tests/test_functional.py:770:9: F841 local variable 'out1' is assigned to but never used -./tests/test_functional.py:1108:5: F841 local variable 'relerr1' is assigned to but never used -./tests/test_functional.py:1108:14: F841 local variable 'relerr2' is assigned to but never used -./tests/test_functional.py:1114:9: F841 local variable 'C1' is assigned to but never used -./tests/test_functional.py:1135:9: F841 local variable 'C4' is assigned to but never used -./tests/test_functional.py:1179:5: F841 local variable 'err1' is assigned to but never used -./tests/test_functional.py:1179:11: F841 local variable 'err2' is assigned to but never used -./tests/test_functional.py:1179:17: F841 local variable 'err3' is assigned to but never used -./tests/test_functional.py:1180:5: F841 local variable 'relerr1' is assigned to but never used -./tests/test_functional.py:1180:14: F841 local variable 'relerr2' is assigned to but never used -./tests/test_functional.py:1192:9: F841 local variable 'C1' is assigned to but never used -./tests/test_functional.py:1313:9: F841 local variable 'c' is assigned to but never used -./tests/test_functional.py:1314:9: F841 local variable 'c2' is assigned to but never used -./tests/test_functional.py:1406:9: F841 local variable 'C1' is assigned to but never used -./tests/test_functional.py:1425:9: F841 local variable 'out2' is assigned to but never used -./tests/test_functional.py:1542:5: F841 local variable 'idx_col' is assigned to but never used -./tests/test_functional.py:1566:30: E203 whitespace before ':' -./tests/test_functional.py:1568:38: E203 whitespace before ':' -./tests/test_functional.py:1655:5: F841 local variable 'offset' is assigned to but never used -./tests/test_functional.py:1706:9: F841 local variable 'out' is assigned to but never used -./tests/test_functional.py:1822:9: F841 local variable 'out' is assigned to but never used -./tests/test_functional.py:1882:5: F841 local variable 'out2' is assigned to but never used -./tests/test_functional.py:1928:9: F841 local variable 'dtype' is assigned to but never used -./tests/test_functional.py:1934:9: F841 local variable 'minx' is assigned to but never used -./tests/test_functional.py:1948:5: F841 local variable 'C0' is assigned to but never used -./tests/test_modules.py:1:1: F401 'itertools.product' imported but unused -./tests/test_modules.py:52:9: F841 local variable 'norm' is assigned to but never used -./tests/test_modules.py:52:16: F821 undefined name 'math' -./tests/test_modules.py:52:26: F821 undefined name 'math' -./tests/test_modules.py:52:37: F821 undefined name 'math' -./tests/test_modules.py:177:21: F821 undefined name 'einops' -./tests/test_modules.py:233:9: F841 local variable 'stochastic' is assigned to but never used -./tests/test_modules.py:382:9: F841 local variable 'o1' is assigned to but never used
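
A closing note on the `howto_config_override.md` hunk above: the snippet it touches carries a small Python syntax slip, `'betas'=(0.9, 0.98)`, where a dict entry needs a colon. Below is a minimal, self-contained version of that example with the colon fixed; the `Model` class is a hypothetical stand-in for the document's `model`, and actually running it assumes a CUDA device and an installed `bitsandbytes`.

```python
import torch
import bitsandbytes as bnb


class Model(torch.nn.Module):  # hypothetical stand-in for the doc's `model`
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(64, 64)
        self.special = torch.nn.Linear(64, 64)
        self.also_special = torch.nn.Linear(64, 64)


model = Model()
mng = bnb.optim.GlobalOptimManager.get_instance()

# 1. register parameters while the model is still on the CPU
mng.register_parameters(model.parameters())
model = model.cuda()

# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)

# 2a. override: model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, 'optim_bits', 32)

# 2b. override: the two special layers use sparse optimization,
#     a different learning rate, and different Adam betas
mng.override_config(
    [model.special.weight, model.also_special.weight],
    key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)},
)
```

The ordering mirrors the numbered comments in the original snippet: parameters are registered before the model is moved to the GPU and before the optimizer is constructed, so the overrides can take effect when the optimizer later builds its per-parameter state.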
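
The same how-to ends by recommending per-layer overrides "locally in each module" by "passing the module, the parameter, and its attribute name to the GlobalOptimManager", but the concrete call sits outside the changed lines of this patch. The sketch below shows one plausible shape of that call under the assumption that the manager exposes a `register_module_override(module, param_name, config)` method; treat the name and signature as illustrative rather than authoritative and check `bitsandbytes/optim` for the exact API.

```python
import torch
import bitsandbytes as bnb


class Net(torch.nn.Module):
    """Hypothetical module that pins one layer's optimizer state to 32-bit."""

    def __init__(self, num_embeddings=1024, dim=64):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings, dim)

        # Assumed API: tell the manager that the 'weight' attribute of
        # self.emb should keep 32-bit optimizer state even when the
        # optimizer created later uses 8-bit states elsewhere.
        mng = bnb.optim.GlobalOptimManager.get_instance()
        mng.register_module_override(self.emb, 'weight', {'optim_bits': 32})


model = Net().cuda()
adam = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)
```

Registering the override inside the module's `__init__` keeps the precision decision next to the layer it concerns, instead of scattering `override_config` calls through the training script.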