Merge pull request #76 from tomaarsen/cleanup

Cleanup involving fixes for a handful of failures, some optimization, and a lot of code quality improvements (the recurring patterns are sketched just below the commit metadata)
This commit is contained in:
Tim Dettmers 2023-01-02 11:19:28 +01:00 committed by GitHub
commit f0ec93d016
41 changed files with 281 additions and 478 deletions
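Most of those 41 files apply the same handful of modernization patterns. A minimal, hypothetical sketch of them, with illustrative names only and not actual bitsandbytes code:

# Recurring cleanup patterns from this PR, condensed into one toy class.
class Pooler:                                         # was: class Pooler(object):
    _instance = None

    def __init__(self, lr=0.001):
        super().__init__()                            # was: super(Pooler, self).__init__()
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")  # was: "...: {}".format(lr)
        self.keys = {"qmap1", "qmap2"}                # set literal, was: set(["qmap1", "qmap2"])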

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from . import cuda_setup, utils
from .autograd._functions import ( from .autograd._functions import (
MatmulLtState, MatmulLtState,
bmm_cublas, bmm_cublas,
@ -12,7 +13,6 @@ from .autograd._functions import (
) )
from .cextension import COMPILED_WITH_CUDA from .cextension import COMPILED_WITH_CUDA
from .nn import modules from .nn import modules
from . import cuda_setup, utils
if COMPILED_WITH_CUDA: if COMPILED_WITH_CUDA:
from .optim import adam from .optim import adam

View File

@ -1,6 +1,3 @@
# from bitsandbytes.debug_cli import cli
# cli()
import os import os
import sys import sys
from warnings import warn from warnings import warn
@ -31,8 +28,8 @@ print()
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
from .cuda_setup.env_vars import to_be_ignored from .cuda_setup.env_vars import to_be_ignored
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS") print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
for k, v in os.environ.items(): for k, v in os.environ.items():

View File

@ -1,12 +1,13 @@
import operator import operator
import warnings import warnings
import torch
import bitsandbytes.functional as F
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce # Required in Python 3 from functools import reduce # Required in Python 3
import torch
import bitsandbytes.functional as F
# math.prod not compatible with python < 3.8 # math.prod not compatible with python < 3.8
def prod(iterable): def prod(iterable):
return reduce(operator.mul, iterable, 1) return reduce(operator.mul, iterable, 1)
@ -18,7 +19,7 @@ tensor = torch.Tensor
This is particularly important for small models where outlier features This is particularly important for small models where outlier features
are less systematic and occur with low frequency. are less systematic and occur with low frequency.
""" """
class GlobalOutlierPooler(object): class GlobalOutlierPooler:
_instance = None _instance = None
def __init__(self): def __init__(self):
@ -49,8 +50,9 @@ class GlobalOutlierPooler(object):
class MatMul8bit(torch.autograd.Function): class MatMul8bit(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]): def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
if precision is None:
precision = [8, 8, 8]
if precision[0] != 8: if precision[0] != 8:
with torch.no_grad(): with torch.no_grad():
output = torch.matmul(A, B) output = torch.matmul(A, B)
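The hunk above also removes a mutable default argument: precision=[8, 8, 8] becomes precision=None with the default filled in inside the function. A minimal standalone sketch of that pattern (hypothetical function, not the real forward):

def forward_sketch(A, B, precision=None):
    # A list default would be created once and shared across all calls;
    # defaulting to None and filling it here keeps each call independent.
    if precision is None:
        precision = [8, 8, 8]
    return precision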

View File

@ -1,11 +1,11 @@
import ctypes as ct import ctypes as ct
import torch
from pathlib import Path from pathlib import Path
from warnings import warn from warnings import warn
import torch
class CUDASetup(object):
class CUDASetup:
_instance = None _instance = None
def __init__(self): def __init__(self):

View File

@ -1,2 +1,6 @@
from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path
from .main import evaluate_cuda_setup from .main import evaluate_cuda_setup
from .paths import (
CUDA_RUNTIME_LIB,
determine_cuda_runtime_lib_path,
extract_candidate_paths,
)

View File

@ -17,11 +17,13 @@ evaluation:
""" """
import ctypes import ctypes
import torch import torch
from .paths import determine_cuda_runtime_lib_path
from bitsandbytes.cextension import CUDASetup from bitsandbytes.cextension import CUDASetup
from .paths import determine_cuda_runtime_lib_path
def check_cuda_result(cuda, result_val): def check_cuda_result(cuda, result_val):
# 3. Check for CUDA errors # 3. Check for CUDA errors
@ -48,7 +50,7 @@ def get_cuda_version(cuda, cudart_path):
minor = (version-(major*1000))//10 minor = (version-(major*1000))//10
if major < 11: if major < 11:
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
return f'{major}{minor}' return f'{major}{minor}'
@ -129,7 +131,7 @@ def evaluate_cuda_setup():
failure = True failure = True
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True) cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
else: else:
cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}")) cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
if cc == '' or cc is None: if cc == '' or cc is None:
failure = True failure = True

View File

@ -1,6 +1,7 @@
import errno import errno
from pathlib import Path from pathlib import Path
from typing import Set, Union from typing import Set, Union
from bitsandbytes.cextension import CUDASetup from bitsandbytes.cextension import CUDASetup
from .env_vars import get_potentially_lib_path_containing_env_vars from .env_vars import get_potentially_lib_path_containing_env_vars

View File

@ -1,26 +0,0 @@
import typer
cli = typer.Typer()
@cli.callback()
def callback():
"""
Awesome Portal Gun
"""
@cli.command()
def shoot():
"""
Shoot the portal gun
"""
typer.echo("Shooting portal gun")
@cli.command()
def load():
"""
Load the portal gun
"""
typer.echo("Loading portal gun")

View File

@ -3,17 +3,19 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import ctypes as ct import ctypes as ct
import itertools
import operator import operator
import random import random
import torch import torch
import itertools import itertools
import math import math
from functools import reduce # Required in Python 3
from typing import Tuple from typing import Tuple
from torch import Tensor from torch import Tensor
from .cextension import COMPILED_WITH_CUDA, lib from .cextension import COMPILED_WITH_CUDA, lib
from functools import reduce # Required in Python 3
# math.prod not compatible with python < 3.8 # math.prod not compatible with python < 3.8
def prod(iterable): def prod(iterable):
@ -84,7 +86,7 @@ if COMPILED_WITH_CUDA:
) )
class CUBLAS_Context(object): class CUBLAS_Context:
_instance = None _instance = None
def __init__(self): def __init__(self):
@ -114,7 +116,7 @@ class CUBLAS_Context(object):
return self.context[device.index] return self.context[device.index]
class Cusparse_Context(object): class Cusparse_Context:
_instance = None _instance = None
def __init__(self): def __init__(self):
@ -264,12 +266,11 @@ def create_quantile_map(A, total_bits=8):
def get_special_format_str(): def get_special_format_str():
if not torch.cuda.is_available(): return 'col_turing' if not torch.cuda.is_available(): return 'col_turing'
major, minor = torch.cuda.get_device_capability() major, _minor = torch.cuda.get_device_capability()
if major <= 7: if major <= 7:
return "col_turing" return "col_turing"
elif major == 8: if major == 8:
return "col_ampere" return "col_ampere"
else:
return "col_turing" return "col_turing"
@ -397,8 +398,6 @@ def nvidia_transform(
dim2 = ct.c_int32(shape[2]) dim2 = ct.c_int32(shape[2])
ptr = CUBLAS_Context.get_instance().get_context(A.device) ptr = CUBLAS_Context.get_instance().get_context(A.device)
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
return out, new_state return out, new_state
@ -1053,7 +1052,7 @@ def histogram_scatter_add_2d(
maxdim1 = ct.c_int32(histogram.shape[0]) maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel()) n = ct.c_int32(index1.numel())
is_on_gpu([histogram, index1, index2d, source]) is_on_gpu([histogram, index1, index2, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
@ -1512,7 +1511,7 @@ def get_colrow_absmax(
return row_stats, col_stats, nnz_block_ptr return row_stats, col_stats, nnz_block_ptr
class COOSparseTensor(object): class COOSparseTensor:
def __init__(self, rows, cols, nnz, rowidx, colidx, values): def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32 assert rowidx.dtype == torch.int32
assert colidx.dtype == torch.int32 assert colidx.dtype == torch.int32
@ -1529,7 +1528,7 @@ class COOSparseTensor(object):
self.values = values self.values = values
class CSRSparseTensor(object): class CSRSparseTensor:
def __init__(self, rows, cols, nnz, rowptr, colidx, values): def __init__(self, rows, cols, nnz, rowptr, colidx, values):
assert rowptr.dtype == torch.int32 assert rowptr.dtype == torch.int32
assert colidx.dtype == torch.int32 assert colidx.dtype == torch.int32
@ -1546,7 +1545,7 @@ class CSRSparseTensor(object):
self.values = values self.values = values
class CSCSparseTensor(object): class CSCSparseTensor:
def __init__(self, rows, cols, nnz, colptr, rowidx, values): def __init__(self, rows, cols, nnz, colptr, rowidx, values):
assert colptr.dtype == torch.int32 assert colptr.dtype == torch.int32
assert rowidx.dtype == torch.int32 assert rowidx.dtype == torch.int32
@ -1710,8 +1709,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
dim1 = ct.c_int32(shape[0] * shape[1]) dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2]) dim2 = ct.c_int32(shape[2])
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
is_on_gpu([A, out]) is_on_gpu([A, out])
if to_order == 'col32': if to_order == 'col32':
if transpose: if transpose:
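Besides the import reordering and the index2 name fix, the functional.py hunks simplify get_special_format_str into early returns. A runnable approximation of the new control flow, assuming only what is visible in the hunk above:

import torch

def special_format_sketch():
    # Early-return version of the branching shown above.
    if not torch.cuda.is_available():
        return 'col_turing'
    major, _minor = torch.cuda.get_device_capability()
    if major <= 7:
        return 'col_turing'
    if major == 8:
        return 'col_ampere'
    return 'col_turing'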

View File

@ -2,24 +2,11 @@
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import ( from typing import Optional, TypeVar, Union, overload
Any,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor, device, dtype, nn from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter
import bitsandbytes as bnb import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.optim import GlobalOptimManager
@ -39,7 +26,7 @@ class StableEmbedding(torch.nn.Embedding):
sparse: bool = False, sparse: bool = False,
_weight: Optional[Tensor] = None, _weight: Optional[Tensor] = None,
) -> None: ) -> None:
super(StableEmbedding, self).__init__( super().__init__(
num_embeddings, num_embeddings,
embedding_dim, embedding_dim,
padding_idx, padding_idx,
@ -96,7 +83,7 @@ class Embedding(torch.nn.Embedding):
sparse: bool = False, sparse: bool = False,
_weight: Optional[Tensor] = None, _weight: Optional[Tensor] = None,
) -> None: ) -> None:
super(Embedding, self).__init__( super().__init__(
num_embeddings, num_embeddings,
embedding_dim, embedding_dim,
padding_idx, padding_idx,
@ -225,7 +212,7 @@ class Linear8bitLt(nn.Linear):
threshold=0.0, threshold=0.0,
index=None, index=None,
): ):
super(Linear8bitLt, self).__init__( super().__init__(
input_features, output_features, bias input_features, output_features, bias
) )
self.state = bnb.MatmulLtState() self.state = bnb.MatmulLtState()

View File

@ -5,12 +5,11 @@
from bitsandbytes.cextension import COMPILED_WITH_CUDA from bitsandbytes.cextension import COMPILED_WITH_CUDA
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .optimizer import GlobalOptimManager from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .sgd import SGD, SGD8bit, SGD32bit

View File

@ -21,18 +21,18 @@ class Adagrad(Optimizer1State):
block_wise=True, block_wise=True,
): ):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay) f"Invalid weight_decay value: {weight_decay}"
) )
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps)) raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!") raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0: if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!") raise ValueError("Lr Decay != 0.0 not supported!")
super(Adagrad, self).__init__( super().__init__(
"adagrad", "adagrad",
params, params,
lr, lr,
@ -63,19 +63,19 @@ class Adagrad8bit(Optimizer1State):
block_wise=True, block_wise=True,
): ):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay) f"Invalid weight_decay value: {weight_decay}"
) )
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps)) raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!") raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0: if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!") raise ValueError("Lr Decay != 0.0 not supported!")
assert block_wise assert block_wise
super(Adagrad8bit, self).__init__( super().__init__(
"adagrad", "adagrad",
params, params,
lr, lr,
@ -106,18 +106,18 @@ class Adagrad32bit(Optimizer1State):
block_wise=True, block_wise=True,
): ):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay) f"Invalid weight_decay value: {weight_decay}"
) )
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps)) raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0: if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!") raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0: if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!") raise ValueError("Lr Decay != 0.0 not supported!")
super(Adagrad32bit, self).__init__( super().__init__(
"adagrad", "adagrad",
params, params,
lr, lr,
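The adagrad hunks above convert the str.format calls in the argument checks to f-strings. A short sketch of the resulting validation pattern, using an illustrative helper rather than the real constructor:

def check_hyperparams(lr, eps, weight_decay):
    # Same checks as in the diff, written with f-strings.
    if not 0.0 <= lr:
        raise ValueError(f"Invalid learning rate: {lr}")
    if not 0.0 <= eps:
        raise ValueError(f"Invalid epsilon value: {eps}")
    if not 0.0 <= weight_decay:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")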

View File

@ -28,7 +28,7 @@ class Adam(Optimizer2State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
super(Adam, self).__init__( super().__init__(
"adam", "adam",
params, params,
lr, lr,
@ -57,7 +57,7 @@ class Adam8bit(Optimizer2State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
super(Adam8bit, self).__init__( super().__init__(
"adam", "adam",
params, params,
lr, lr,
@ -86,7 +86,7 @@ class Adam32bit(Optimizer2State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
super(Adam32bit, self).__init__( super().__init__(
"adam", "adam",
params, params,
lr, lr,
@ -146,7 +146,7 @@ class AnalysisAdam(torch.optim.Optimizer):
weight_decay=weight_decay, weight_decay=weight_decay,
amsgrad=amsgrad, amsgrad=amsgrad,
) )
super(AnalysisAdam, self).__init__(params, defaults) super().__init__(params, defaults)
self.analysis = bnb_analysis self.analysis = bnb_analysis
self.savedir = savedir self.savedir = savedir

View File

@ -20,7 +20,7 @@ class AdamW(Optimizer2State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
super(AdamW, self).__init__( super().__init__(
"adam", "adam",
params, params,
lr, lr,
@ -49,7 +49,7 @@ class AdamW8bit(Optimizer2State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
super(AdamW8bit, self).__init__( super().__init__(
"adam", "adam",
params, params,
lr, lr,
@ -78,7 +78,7 @@ class AdamW32bit(Optimizer2State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
super(AdamW32bit, self).__init__( super().__init__(
"adam", "adam",
params, params,
lr, lr,

View File

@ -23,7 +23,7 @@ class LAMB(Optimizer2State):
block_wise=False, block_wise=False,
max_unorm=1.0, max_unorm=1.0,
): ):
super(LAMB, self).__init__( super().__init__(
"lamb", "lamb",
params, params,
lr, lr,
@ -56,7 +56,7 @@ class LAMB8bit(Optimizer2State):
block_wise=False, block_wise=False,
max_unorm=1.0, max_unorm=1.0,
): ):
super(LAMB8bit, self).__init__( super().__init__(
"lamb", "lamb",
params, params,
lr, lr,
@ -89,7 +89,7 @@ class LAMB32bit(Optimizer2State):
block_wise=False, block_wise=False,
max_unorm=1.0, max_unorm=1.0,
): ):
super(LAMB32bit, self).__init__( super().__init__(
"lamb", "lamb",
params, params,
lr, lr,

View File

@ -25,9 +25,9 @@ class LARS(Optimizer1State):
): ):
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError(
f"LARS without momentum is not supported!" "LARS without momentum is not supported!"
) )
super(LARS, self).__init__( super().__init__(
"lars", "lars",
params, params,
lr, lr,
@ -59,9 +59,9 @@ class LARS8bit(Optimizer1State):
): ):
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError(
f"LARS without momentum is not supported!" "LARS without momentum is not supported!"
) )
super(LARS8bit, self).__init__( super().__init__(
"lars", "lars",
params, params,
lr, lr,
@ -93,9 +93,9 @@ class LARS32bit(Optimizer1State):
): ):
if momentum == 0: if momentum == 0:
raise NotImplementedError( raise NotImplementedError(
f"LARS without momentum is not supported!" "LARS without momentum is not supported!"
) )
super(LARS32bit, self).__init__( super().__init__(
"lars", "lars",
params, params,
lr, lr,
@ -123,12 +123,12 @@ class PytorchLARS(Optimizer):
max_unorm=0.02, max_unorm=0.02,
): ):
if lr < 0.0: if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0: if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum)) raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0: if weight_decay < 0.0:
raise ValueError( raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay) f"Invalid weight_decay value: {weight_decay}"
) )
defaults = dict( defaults = dict(
@ -143,10 +143,10 @@ class PytorchLARS(Optimizer):
raise ValueError( raise ValueError(
"Nesterov momentum requires a momentum and zero dampening" "Nesterov momentum requires a momentum and zero dampening"
) )
super(PytorchLARS, self).__init__(params, defaults) super().__init__(params, defaults)
def __setstate__(self, state): def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state) super().__setstate__(state)
for group in self.param_groups: for group in self.param_groups:
group.setdefault("nesterov", False) group.setdefault("nesterov", False)
@ -181,7 +181,7 @@ class PytorchLARS(Optimizer):
state = self.state[p] state = self.state[p]
d_p = p.grad d_p = p.grad
if weight_decay != 0: if weight_decay != 0:
d_p = d_p.add(param, alpha=weight_decay) d_p = d_p.add(p, alpha=weight_decay)
if momentum != 0: if momentum != 0:
buf = state.get("momentum_buffer", None) buf = state.get("momentum_buffer", None)
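The PytorchLARS hunk above replaces the undefined name param with p in the weight-decay step (the F821 warning also listed in the removed lint report near the end of this diff). A minimal sketch of the corrected step on toy tensors:

import torch

p = torch.randn(4, requires_grad=True)
p.grad = torch.randn(4)
weight_decay = 0.01

with torch.no_grad():
    d_p = p.grad
    if weight_decay != 0:
        # gradient plus weight_decay * parameter, as in the corrected line
        d_p = d_p.add(p, alpha=weight_decay)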

View File

@ -12,13 +12,13 @@ import torch
import bitsandbytes.functional as F import bitsandbytes.functional as F
class MockArgs(object): class MockArgs:
def __init__(self, initial_data): def __init__(self, initial_data):
for key in initial_data: for key in initial_data:
setattr(self, key, initial_data[key]) setattr(self, key, initial_data[key])
class GlobalOptimManager(object): class GlobalOptimManager:
_instance = None _instance = None
def __init__(self): def __init__(self):
@ -56,9 +56,9 @@ class GlobalOptimManager(object):
""" """
Overrides initial optimizer config for specific parameters. Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overidden The key-values of the optimizer config for the input parameters are overridden
This can be both, optimizer parameters like "betas", or "lr" or it can be This can be both, optimizer parameters like "betas", or "lr" or it can be
8-bit specific paramters like "optim_bits", "percentile_clipping". 8-bit specific parameters like "optim_bits", "percentile_clipping".
Parameters Parameters
---------- ----------
@ -93,13 +93,12 @@ class GlobalOptimManager(object):
class Optimizer8bit(torch.optim.Optimizer): class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32): def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults) super().__init__(params, defaults)
self.initialized = False self.initialized = False
self.name2qmap = {} self.name2qmap = {}
self.mng = GlobalOptimManager.get_instance() self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set( self.non_castable_tensor_keys = {
[
"qmap1", "qmap1",
"qmap2", "qmap2",
"max1", "max1",
@ -112,8 +111,7 @@ class Optimizer8bit(torch.optim.Optimizer):
"absmax1", "absmax1",
"absmax2", "absmax2",
"unorm_vec", "unorm_vec",
] }
)
if optim_bits == 8: if optim_bits == 8:
self.fill_qmap() self.fill_qmap()
@ -123,7 +121,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False) self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state): def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state) super().__setstate__(state)
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
r"""Loads the optimizer state. r"""Loads the optimizer state.
@ -155,8 +153,8 @@ class Optimizer8bit(torch.optim.Optimizer):
id_map = { id_map = {
old_id: p old_id: p
for old_id, p in zip( for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)), chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable((g["params"] for g in groups)), chain.from_iterable(g["params"] for g in groups),
) )
} }
@ -284,11 +282,11 @@ class Optimizer8bit(torch.optim.Optimizer):
return config return config
def init_state(self, group, p, gindex, pindex): def init_state(self, group, p, gindex, pindex):
raise NotImplementedError(f"init_state method needs to be overidden") raise NotImplementedError("init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex): def update_step(self, group, p, gindex, pindex):
raise NotImplementedError( raise NotImplementedError(
f"The update_step method needs to be overidden" "The update_step method needs to be overridden"
) )
@ -310,9 +308,9 @@ class Optimizer2State(Optimizer8bit):
skip_zeros=False, skip_zeros=False,
): ):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps)) raise ValueError(f"Invalid epsilon value: {eps}")
if isinstance(betas, str): if isinstance(betas, str):
# format: '(beta1, beta2)' # format: '(beta1, beta2)'
betas = betas.replace("(", "").replace(")", "").strip().split(",") betas = betas.replace("(", "").replace(")", "").strip().split(",")
@ -324,10 +322,10 @@ class Optimizer2State(Optimizer8bit):
) )
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay) f"Invalid weight_decay value: {weight_decay}"
) )
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits) super().__init__(params, defaults, optim_bits)
if args is None: if args is None:
args = {} args = {}
@ -542,9 +540,9 @@ class Optimizer1State(Optimizer8bit):
skip_zeros=False, skip_zeros=False,
): ):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps)) raise ValueError(f"Invalid epsilon value: {eps}")
for i in range(len(betas)): for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0: if not 0.0 <= betas[i] < 1.0:
raise ValueError( raise ValueError(
@ -552,10 +550,10 @@ class Optimizer1State(Optimizer8bit):
) )
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError( raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay) f"Invalid weight_decay value: {weight_decay}"
) )
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits) super().__init__(params, defaults, optim_bits)
if args is None: if args is None:
args = {} args = {}
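Two optimizer.py hunks above swap set([...]) for a set literal and drop the redundant parentheses around the generator expressions passed to chain.from_iterable. A small sketch of both idioms on toy data, not real optimizer state:

from itertools import chain

non_castable = {"qmap1", "qmap2", "max1"}   # set literal, was: set([...])

saved_groups = [{"params": [1, 2]}, {"params": [3]}]
groups = [{"params": ["a", "b"]}, {"params": ["c"]}]
id_map = {
    old_id: p
    for old_id, p in zip(
        chain.from_iterable(g["params"] for g in saved_groups),
        chain.from_iterable(g["params"] for g in groups),
    )
}
# id_map == {1: 'a', 2: 'b', 3: 'c'}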

View File

@ -23,11 +23,11 @@ class RMSprop(Optimizer1State):
): ):
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!" "RMSprop with alpha==0.0 is not supported!"
) )
if centered: if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super(RMSprop, self).__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,
@ -59,11 +59,11 @@ class RMSprop8bit(Optimizer1State):
): ):
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!" "RMSprop with alpha==0.0 is not supported!"
) )
if centered: if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super(RMSprop8bit, self).__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,
@ -96,11 +96,11 @@ class RMSprop32bit(Optimizer1State):
if alpha == 0: if alpha == 0:
raise NotImplementedError( raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!" "RMSprop with alpha==0.0 is not supported!"
) )
if centered: if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!") raise NotImplementedError("Centered RMSprop is not supported!")
super(RMSprop32bit, self).__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,

View File

@ -21,8 +21,8 @@ class SGD(Optimizer1State):
block_wise=True, block_wise=True,
): ):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!") raise NotImplementedError("SGD without momentum is not supported!")
super(SGD, self).__init__( super().__init__(
"momentum", "momentum",
params, params,
lr, lr,
@ -52,8 +52,8 @@ class SGD8bit(Optimizer1State):
block_wise=True, block_wise=True,
): ):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!") raise NotImplementedError("SGD without momentum is not supported!")
super(SGD8bit, self).__init__( super().__init__(
"momentum", "momentum",
params, params,
lr, lr,
@ -83,8 +83,8 @@ class SGD32bit(Optimizer1State):
block_wise=True, block_wise=True,
): ):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!") raise NotImplementedError("SGD without momentum is not supported!")
super(SGD32bit, self).__init__( super().__init__(
"momentum", "momentum",
params, params,
lr, lr,

View File

@ -121,5 +121,3 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
#endif #endif

View File

@ -290,4 +290,3 @@ extern "C"
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
} }

View File

@ -76,6 +76,3 @@ if [[ -n "$CUDA_VERSION" ]]; then
else else
echo "" echo ""
fi fi

View File

@ -26,9 +26,6 @@ setup(
keywords="gpu optimizers optimization 8-bit quantization compression", keywords="gpu optimizers optimization 8-bit quantization compression",
url="https://github.com/TimDettmers/bitsandbytes", url="https://github.com/TimDettmers/bitsandbytes",
packages=find_packages(), packages=find_packages(),
entry_points={
"console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
},
package_data={"": libs}, package_data={"": libs},
long_description=read("README.md"), long_description=read("README.md"),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@ -1,4 +1,4 @@
from itertools import product, permutations from itertools import permutations, product
import pytest import pytest
import torch import torch
@ -27,7 +27,7 @@ str_values = list(
) )
) )
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format( "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
*vals *vals
) )
for vals in str_values for vals in str_values
@ -286,7 +286,7 @@ str_values = list(
has_bias has_bias
) )
) )
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values] names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -1,13 +1,13 @@
import os import os
import pytest
import bitsandbytes as bnb
from typing import List, NamedTuple from typing import List, NamedTuple
import pytest
import bitsandbytes as bnb
from bitsandbytes.cuda_setup import ( from bitsandbytes.cuda_setup import (
CUDA_RUNTIME_LIB, CUDA_RUNTIME_LIB,
evaluate_cuda_setup,
determine_cuda_runtime_lib_path, determine_cuda_runtime_lib_path,
evaluate_cuda_setup,
extract_candidate_paths, extract_candidate_paths,
) )

View File

@ -28,7 +28,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
class FFN(torch.nn.Module): class FFN(torch.nn.Module):
def __init__(self, input_features, hidden_size, bias=True): def __init__(self, input_features, hidden_size, bias=True):
super(FFN, self).__init__() super().__init__()
self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias) self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias) self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
@ -42,7 +42,7 @@ class FFN(torch.nn.Module):
return x return x
class Timer(object): class Timer:
def __init__(self): def __init__(self):
self.starts = {} self.starts = {}
self.ends = {} self.ends = {}
@ -69,7 +69,7 @@ class Timer(object):
self.ends.pop(name) self.ends.pop(name)
if print_ms and name in self.agg: if print_ms and name in self.agg:
print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0)) print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")
return self.agg[name] return self.agg[name]
@ -302,7 +302,7 @@ batched = [False, True]
values = list(product(dim1, dim2, methods, batched)) values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched)) values_names = list(product(dim1, dim2, method_names, batched))
names = [ names = [
"dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
for vals in values_names for vals in values_names
] ]
@ -360,7 +360,7 @@ seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)] transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim)) values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [ names = [
"hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals) "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
for vals in values for vals in values
] ]
@ -425,7 +425,7 @@ hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist() batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim)) values = list(product(seq_dim, hidden_dim, batch_dim))
names = [ names = [
"seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
] ]
@ -457,7 +457,7 @@ batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True] transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose)) values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [ names = [
"seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals) "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
for vals in values for vals in values
] ]
@ -542,7 +542,7 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)] transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose)) values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
for vals in values for vals in values
] ]
@ -580,7 +580,7 @@ dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist() dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist() dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3)) values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) @pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
@ -609,7 +609,7 @@ transpose = [False]
dims = [2, 3] dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values] names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names) @pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
@ -691,7 +691,7 @@ ldb = [0]
# ldb = list(range(256, 1*1024, 256)) # ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals) "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
for vals in values for vals in values
] ]
@ -739,7 +739,7 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256)) # ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims)) values = list(product(dim1, dim2, dim3, dim4, dims))
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
for vals in values for vals in values
] ]
@ -797,7 +797,7 @@ values = [
# values = list(product(batch, seq, model, hidden)) # values = list(product(batch, seq, model, hidden))
names = [ names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
] ]
@ -965,7 +965,7 @@ dims = (2,)
formatB = ["col_turing", "col_ampere"] formatB = ["col_turing", "col_ampere"]
has_bias = [True, False] has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias)) values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values] names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) @pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
@ -1015,7 +1015,7 @@ dim2 = [1 * 1024]
dims = (2,) dims = (2,)
# ldb = list(range(256, 1*1024, 256)) # ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims)) values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) @pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
@ -1071,7 +1071,7 @@ dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2)) values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names) @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@ -1118,7 +1118,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner)) values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@ -1162,7 +1162,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner)) values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@ -1237,7 +1237,7 @@ inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096] dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner)) values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@ -1303,7 +1303,7 @@ values = list(
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose) product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
) )
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format( "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
*vals *vals
) )
for vals in values for vals in values
@ -1354,7 +1354,7 @@ a_order = ["col_turing"]
out_order = ["row"] out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order)) values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [ names = [
"dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals) "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
for vals in values for vals in values
] ]
@ -1380,7 +1380,7 @@ dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim2 = [5] # dim2 = [5]
values = list(product(dim1, dim2)) values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names) @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@ -1417,7 +1417,7 @@ dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim2 = [11] # dim2 = [11]
transposed_B = [False, True] transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B)) values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) @pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
@ -1498,7 +1498,7 @@ n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist() dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2)) values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names) @pytest.mark.parametrize("dim1, dim2", values, ids=names)
@ -1563,7 +1563,7 @@ dtype = [torch.float16]
out_function = ["zeros", "ones"] out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function)) values = list(product(dim1, dim2, dtype, out_function))
names = [ names = [
"dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
] ]
@ -1680,7 +1680,7 @@ dim2 = [2048]
# dim2 = [2] # dim2 = [2]
dtype = [torch.int8] dtype = [torch.int8]
values = list(product(dim1, dim2, dtype)) values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) @pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
@ -1796,7 +1796,7 @@ values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 5140, 4*5140)) # values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288)) #values.append((batch_size, seqdim, 12288, 4*12288))
names = [ names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
] ]
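The test hunks above drop the explicit positional indices from str.format ("{0}" becomes "{}"). A tiny sketch of how such names feed pytest parametrization, with hypothetical dimensions:

from itertools import product

import pytest

values = list(product([32, 64], [128]))
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]

@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_shapes_sketch(dim1, dim2):
    assert dim1 * dim2 > 0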

View File

@ -7,7 +7,7 @@ from torch import nn
import bitsandbytes as bnb import bitsandbytes as bnb
class MockArgs(object): class MockArgs:
def __init__(self, initial_data): def __init__(self, initial_data):
for key in initial_data: for key in initial_data:
setattr(self, key, initial_data[key]) setattr(self, key, initial_data[key])
@ -15,7 +15,7 @@ class MockArgs(object):
class MLP8bit(torch.nn.Module): class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super(MLP8bit, self).__init__() super().__init__()
self.fc1 = bnb.nn.Linear8bitLt( self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold threshold=threshold
@ -289,7 +289,7 @@ class LinearFunction(torch.autograd.Function):
class Linear8bit(nn.Module): class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None): def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__() super().__init__()
self.input_features = input_features self.input_features = input_features
self.output_features = output_features self.output_features = output_features
self.args = args self.args = args
@ -312,7 +312,7 @@ class Linear8bit(nn.Module):
threshold = [0.0, 3.0] threshold = [0.0, 3.0]
values = threshold values = threshold
names = ["threshold_{0}".format(vals) for vals in values] names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names) @pytest.mark.parametrize("threshold", values, ids=names)
@ -378,7 +378,7 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0] threshold = [0.0, 2.0]
values = threshold values = threshold
names = ["threshold_{0}".format(vals) for vals in values] names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names) @pytest.mark.parametrize("threshold", values, ids=names)

View File

@ -18,7 +18,7 @@ k = 20
def get_temp_dir(): def get_temp_dir():
path = "/tmp/autoswap/{0}".format(str(uuid.uuid4())) path = f"/tmp/autoswap/{str(uuid.uuid4())}"
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
return path return path
@ -116,7 +116,7 @@ gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"] optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = [ names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
] ]
@ -187,7 +187,7 @@ dim1 = [1024]
dim2 = [32, 1024, 4097] dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype)) values = list(product(dim1, dim2, gtype))
names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
@ -250,7 +250,7 @@ optimizer_names = [
] ]
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = [ names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
] ]
@ -391,7 +391,7 @@ gtype = [torch.float32]
optim_bits = [32, 8] optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits)) values = list(product(dim1, dim2, gtype, optim_bits))
names = [ names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
for vals in values for vals in values
] ]
@ -495,7 +495,7 @@ gtype = [torch.float32, torch.float16]
optimizer_names = ["adam8bit_blockwise"] optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = [ names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
] ]

View File

@ -1,149 +0,0 @@
./setup.py:20:10: F541 f-string is missing placeholders
./setup.py:21:13: F541 f-string is missing placeholders
./quicktest.py:5:1: F401 'bitsandbytes as bnb' imported but unused
./bitsandbytes/cuda_setup.py:42:56: F821 undefined name 'error_str'
./bitsandbytes/cuda_setup.py:43:15: F541 f-string is missing placeholders
./bitsandbytes/cuda_setup.py:67:5: F841 local variable 'context' is assigned to but never used
./bitsandbytes/cuda_setup.py:68:5: F841 local variable 'error_str' is assigned to but never used
./bitsandbytes/cuda_setup.py:76:9: F841 local variable 'result' is assigned to but never used
./bitsandbytes/cuda_setup.py:144:13: F841 local variable 'has_gpu' is assigned to but never used
./bitsandbytes/functional.py:294:13: F821 undefined name 'math'
./bitsandbytes/functional.py:295:16: F821 undefined name 'math'
./bitsandbytes/functional.py:303:5: F841 local variable 'ptrA' is assigned to but never used
./bitsandbytes/functional.py:304:5: F841 local variable 'ptrOut' is assigned to but never used
./bitsandbytes/functional.py:1057:17: W503 line break before binary operator
./bitsandbytes/functional.py:1058:17: W503 line break before binary operator
./bitsandbytes/functional.py:1059:17: W503 line break before binary operator
./bitsandbytes/functional.py:1649:1: F811 redefinition of unused 'get_special_format_str' from line 160
./bitsandbytes/functional.py:1687:5: F841 local variable 'ptrA' is assigned to but never used
./bitsandbytes/functional.py:1688:5: F841 local variable 'ptrOut' is assigned to but never used
./bitsandbytes/functional.py:1802:5: F841 local variable 'ccolsA' is assigned to but never used
./bitsandbytes/functional.py:1805:5: F841 local variable 'cldb' is assigned to but never used
./bitsandbytes/functional.py:1806:5: F841 local variable 'cldc' is assigned to but never used
./bitsandbytes/functional.py:1873:9: F841 local variable 'dtype' is assigned to but never used
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.MatmulLtState' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.bmm_cublas' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul_cublas' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.mm_cublas' imported but unused
./bitsandbytes/__init__.py:9:1: F401 '.nn.modules' imported but unused
./bitsandbytes/__init__.py:12:5: F401 '.optim.adam' imported but unused
./bitsandbytes/autograd/_functions.py:5:1: F401 'bitsandbytes as bnb' imported but unused
./bitsandbytes/autograd/_functions.py:12:75: W291 trailing whitespace
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Int8Params' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bit' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bitLt' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.StableEmbedding' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Any' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Callable' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Dict' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Iterator' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Mapping' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Set' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Tuple' imported but unused
./bitsandbytes/nn/modules.py:11:1: F401 'torch.nn.parameter.Parameter' imported but unused
./bitsandbytes/nn/modules.py:183:13: W503 line break before binary operator
./bitsandbytes/nn/modules.py:184:13: W503 line break before binary operator
./bitsandbytes/nn/modules.py:272:24: F821 undefined name 'dist'
./bitsandbytes/nn/modules.py:272:49: F821 undefined name 'dist'
./bitsandbytes/optim/optimizer.py:243:9: F841 local variable 'overflows' is assigned to but never used
./bitsandbytes/optim/optimizer.py:280:35: F541 f-string is missing placeholders
./bitsandbytes/optim/optimizer.py:283:35: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:27:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:59:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:91:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:157:13: F841 local variable 'params_with_grad' is assigned to but never used
./bitsandbytes/optim/lars.py:158:13: F841 local variable 'd_p_list' is assigned to but never used
./bitsandbytes/optim/lars.py:159:13: F841 local variable 'momentum_buffer_list' is assigned to but never used
./bitsandbytes/optim/lars.py:174:35: F821 undefined name 'param'
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam' imported but unused
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam8bit' imported but unused
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam32bit' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW8bit' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW32bit' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD8bit' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD32bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS8bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS32bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.PytorchLARS' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB8bit' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB32bit' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop8bit' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop32bit' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad8bit' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad32bit' imported but unused
./bitsandbytes/optim/__init__.py:17:1: F401 '.optimizer.GlobalOptimManager' imported but unused
./bitsandbytes/optim/adam.py:229:21: F841 local variable 'max_exp_avg_sq' is assigned to but never used
./bitsandbytes/optim/rmsprop.py:25:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:27:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:59:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:61:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:94:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:96:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:24:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:55:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:86:39: F541 f-string is missing placeholders
./tests/test_optim.py:1:1: F401 'ctypes' imported but unused
./tests/test_optim.py:199:5: F841 local variable 'mask' is assigned to but never used
./tests/test_optim.py:218:9: F841 local variable 'atol' is assigned to but never used
./tests/test_optim.py:218:15: F841 local variable 'rtol' is assigned to but never used
./tests/test_optim.py:304:17: W503 line break before binary operator
./tests/test_optim.py:354:21: W503 line break before binary operator
./tests/test_autograd.py:309:13: F841 local variable 'err' is assigned to but never used
./tests/test_cuda_setup_evaluator.py:31:9: F821 undefined name 'test_dir'
./tests/test_cuda_setup_evaluator.py:33:14: F821 undefined name 'test_input'
./tests/test_cuda_setup_evaluator.py:81:32: E203 whitespace before ':'
./tests/test_functional.py:55:13: F841 local variable 'ms' is assigned to but never used
./tests/test_functional.py:177:5: F841 local variable 'diffs' is assigned to but never used
./tests/test_functional.py:178:5: F841 local variable 'reldiffs' is assigned to but never used
./tests/test_functional.py:260:5: F841 local variable 'minA' is assigned to but never used
./tests/test_functional.py:261:5: F841 local variable 'maxA' is assigned to but never used
./tests/test_functional.py:584:5: F841 local variable 'func' is assigned to but never used
./tests/test_functional.py:617:17: F841 local variable 'offset' is assigned to but never used
./tests/test_functional.py:618:17: F841 local variable 'col2' is assigned to but never used
./tests/test_functional.py:619:17: F841 local variable 'row2' is assigned to but never used
./tests/test_functional.py:705:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:706:9: F841 local variable 'C2' is assigned to but never used
./tests/test_functional.py:715:9: F841 local variable 'output' is assigned to but never used
./tests/test_functional.py:750:5: F841 local variable 'formatB' is assigned to but never used
./tests/test_functional.py:754:5: F841 local variable 'w2' is assigned to but never used
./tests/test_functional.py:763:5: F841 local variable 'dtype' is assigned to but never used
./tests/test_functional.py:770:9: F841 local variable 'out1' is assigned to but never used
./tests/test_functional.py:1108:5: F841 local variable 'relerr1' is assigned to but never used
./tests/test_functional.py:1108:14: F841 local variable 'relerr2' is assigned to but never used
./tests/test_functional.py:1114:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1135:9: F841 local variable 'C4' is assigned to but never used
./tests/test_functional.py:1179:5: F841 local variable 'err1' is assigned to but never used
./tests/test_functional.py:1179:11: F841 local variable 'err2' is assigned to but never used
./tests/test_functional.py:1179:17: F841 local variable 'err3' is assigned to but never used
./tests/test_functional.py:1180:5: F841 local variable 'relerr1' is assigned to but never used
./tests/test_functional.py:1180:14: F841 local variable 'relerr2' is assigned to but never used
./tests/test_functional.py:1192:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1313:9: F841 local variable 'c' is assigned to but never used
./tests/test_functional.py:1314:9: F841 local variable 'c2' is assigned to but never used
./tests/test_functional.py:1406:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1425:9: F841 local variable 'out2' is assigned to but never used
./tests/test_functional.py:1542:5: F841 local variable 'idx_col' is assigned to but never used
./tests/test_functional.py:1566:30: E203 whitespace before ':'
./tests/test_functional.py:1568:38: E203 whitespace before ':'
./tests/test_functional.py:1655:5: F841 local variable 'offset' is assigned to but never used
./tests/test_functional.py:1706:9: F841 local variable 'out' is assigned to but never used
./tests/test_functional.py:1822:9: F841 local variable 'out' is assigned to but never used
./tests/test_functional.py:1882:5: F841 local variable 'out2' is assigned to but never used
./tests/test_functional.py:1928:9: F841 local variable 'dtype' is assigned to but never used
./tests/test_functional.py:1934:9: F841 local variable 'minx' is assigned to but never used
./tests/test_functional.py:1948:5: F841 local variable 'C0' is assigned to but never used
./tests/test_modules.py:1:1: F401 'itertools.product' imported but unused
./tests/test_modules.py:52:9: F841 local variable 'norm' is assigned to but never used
./tests/test_modules.py:52:16: F821 undefined name 'math'
./tests/test_modules.py:52:26: F821 undefined name 'math'
./tests/test_modules.py:52:37: F821 undefined name 'math'
./tests/test_modules.py:177:21: F821 undefined name 'einops'
./tests/test_modules.py:233:9: F841 local variable 'stochastic' is assigned to but never used
./tests/test_modules.py:382:9: F841 local variable 'o1' is assigned to but never used