Merge branch 'main' into main

Tim Dettmers 2023-01-02 11:23:17 +01:00 committed by GitHub
commit be5cecb88f
41 changed files with 746 additions and 631 deletions

View File

@ -149,3 +149,9 @@ Bug fixes:
Bug fixes:
- Fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected.
### 0.35.4
Bug fixes:
- Fixed a bug where the CUDA Setup failed when the CUDA runtime was found, but not the CUDA library.
- Fixed a bug where not finding the cuda runtime led to an incomprehensible error.

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import cuda_setup, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
@ -12,7 +13,6 @@ from .autograd._functions import (
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
from . import cuda_setup, utils
if COMPILED_WITH_CUDA:
from .optim import adam

View File

@ -1,6 +1,3 @@
# from bitsandbytes.debug_cli import cli
# cli()
import os
import sys
from warnings import warn
@ -31,8 +28,8 @@ print()
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
from .cuda_setup.env_vars import to_be_ignored
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
for k, v in os.environ.items():

View File

@ -1,12 +1,13 @@
import operator
import warnings
import torch
import bitsandbytes.functional as F
from dataclasses import dataclass
from functools import reduce # Required in Python 3
import torch
import bitsandbytes.functional as F
# math.prod not compatible with python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)
@ -18,7 +19,7 @@ tensor = torch.Tensor
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler(object):
class GlobalOutlierPooler:
_instance = None
def __init__(self):
@ -49,8 +50,9 @@ class GlobalOutlierPooler(object):
class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
if precision is None:
precision = [8, 8, 8]
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
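The change above swaps the mutable list default precision=[8, 8, 8] for a None sentinel that is filled in inside forward. A minimal standalone sketch of that idiom follows; the function name is illustrative and not part of the library.

# Minimal sketch of the None-sentinel default idiom used above (illustrative name only):
def scale_inputs(values, precision=None):
    # A literal list default would be created once and shared across every call;
    # creating it per call keeps invocations independent.
    if precision is None:
        precision = [8, 8, 8]
    return [v / p for v, p in zip(values, precision)]

assert scale_inputs([8, 16, 24]) == [1.0, 2.0, 3.0]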

View File

@ -1,11 +1,11 @@
import ctypes as ct
import torch
from pathlib import Path
from warnings import warn
import torch
class CUDASetup(object):
class CUDASetup:
_instance = None
def __init__(self):
@ -52,8 +52,13 @@ class CUDASetup(object):
self.add_log_entry('python setup.py install')
def initialize(self):
self.cuda_setup_log = []
self.has_printed = False
self.lib = None
self.run_cuda_setup()
def run_cuda_setup(self):
self.initialized = True
self.cuda_setup_log = []
from .cuda_setup.main import evaluate_cuda_setup
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
@ -89,7 +94,8 @@ class CUDASetup(object):
else:
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
self.lib = ct.cdll.LoadLibrary(binary_path)
except:
except Exception as ex:
self.add_log_entry(str(ex))
self.print_log_stack()
def add_log_entry(self, msg, is_warning=False):
@ -116,7 +122,7 @@ try:
CUDASetup.get_instance().generate_instructions()
CUDASetup.get_instance().print_log_stack()
raise RuntimeError('''
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs to fix your environment!
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
If you cannot find any issues and suspect a bug, please open an issue with details about your environment:
https://github.com/TimDettmers/bitsandbytes/issues''')
lib.cadam32bit_g32
@ -124,8 +130,6 @@ try:
lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
warn(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable."
)
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False
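The try/except above downgrades to a CPU-only build and records the result in COMPILED_WITH_CUDA. A minimal sketch of gating GPU-only features on that flag, assuming only the flag and the bnb.optim.Adam8bit / torch.optim.Adam classes referenced elsewhere in this diff:

# Minimal sketch (assumes only COMPILED_WITH_CUDA and the optimizer classes named in this diff):
import torch
from bitsandbytes.cextension import COMPILED_WITH_CUDA

def make_adam(params, lr=1e-3):
    # Use the 8-bit optimizer only when the CUDA extension loaded successfully;
    # otherwise fall back to the stock 32-bit torch optimizer.
    if COMPILED_WITH_CUDA:
        import bitsandbytes as bnb
        return bnb.optim.Adam8bit(params, lr=lr)
    return torch.optim.Adam(params, lr=lr)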

View File

@ -1,2 +1,6 @@
from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path
from .main import evaluate_cuda_setup
from .paths import (
CUDA_RUNTIME_LIB,
determine_cuda_runtime_lib_path,
extract_candidate_paths,
)

View File

@ -19,9 +19,12 @@ evaluation:
import ctypes
import os
from .paths import determine_cuda_runtime_lib_path
import torch
from bitsandbytes.cextension import CUDASetup
from .paths import determine_cuda_runtime_lib_path
def check_cuda_result(cuda, result_val):
# 3. Check for CUDA errors
@ -30,8 +33,11 @@ def check_cuda_result(cuda, result_val):
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
CUDASetup.get_instance().add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}")
def get_cuda_version(cuda, cudart_path):
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
def get_cuda_version(cuda, cudart_path):
if cuda is None: return None
try:
cudart = ctypes.CDLL(cudart_path)
except OSError:
@ -45,7 +51,7 @@ def get_cuda_version(cuda, cudart_path):
minor = (version-(major*1000))//10
if major < 11:
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!!')
return f'{major}{minor}'
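get_cuda_version turns the integer reported by cudaRuntimeGetVersion into the short string used in the binary names below. A worked example of that arithmetic, assuming the CUDART encoding version = major*1000 + minor*10:

# Worked example of the version arithmetic above (assumes version = major*1000 + minor*10):
def decode_cudart_version(version: int) -> str:
    major = version // 1000                   # 11070 // 1000 -> 11
    minor = (version - major * 1000) // 10    # (11070 - 11000) // 10 -> 7
    return f'{major}{minor}'

assert decode_cudart_version(11070) == '117'  # matches libbitsandbytes_cuda117.so
assert decode_cudart_version(10020) == '102'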
@ -73,7 +79,6 @@ def get_compute_capabilities(cuda):
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
"""
nGpus = ctypes.c_int()
cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()
@ -100,11 +105,11 @@ def get_compute_capability(cuda):
capabilities are downwards compatible. If no GPUs are detected, it returns
None.
"""
ccs = get_compute_capabilities(cuda)
if ccs:
if cuda is None: return None
# TODO: handle different compute capabilities; for now, take the max
return ccs[-1]
return None
ccs = get_compute_capabilities(cuda)
if ccs: return ccs[-1]
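As the docstring notes, only the highest compute capability matters because capabilities are downward compatible, and None signals that no GPU was found. A tiny sketch of that selection, assuming the capability strings are sorted ascending as get_compute_capabilities produces them:

# Sketch of the "take the max" policy described above (assumes an ascending-sorted list):
def highest_cc(ccs):
    if not ccs:
        return None       # no GPUs detected
    return ccs[-1]        # last entry is the highest capability

assert highest_cc(['7.0', '7.5', '8.0']) == '8.0'
assert highest_cc([]) is None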
def evaluate_cuda_setup():
@ -113,31 +118,32 @@ def evaluate_cuda_setup():
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('To hide this message, set the BITSANDBYTES_NOWELCOME variable like so: export BITSANDBYTES_NOWELCOME=1')
print('='*80)
# if not torch.cuda.is_available():
# print('No GPU detected. Loading CPU library...')
# return binary_name
binary_name = "libbitsandbytes_cpu.so"
if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None
cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
return binary_name
cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}"))
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
cuda_version_string = get_cuda_version(cuda, cudart_path)
failure = False
if cudart_path is None:
failure = True
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
else:
cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
if cc == '' or cc is None:
cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library...", is_warning=True)
return binary_name, cudart_path, cuda, cc, cuda_version_string
failure = True
cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True)
else:
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
if cuda is None:
failure = True
else:
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
# 7.5 is the minimum CC for cublaslt
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
@ -148,16 +154,13 @@ def evaluate_cuda_setup():
# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
def get_binary_name():
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
bin_base_name = "libbitsandbytes_cuda"
if has_cublaslt:
return f"{bin_base_name}{cuda_version_string}.so"
if failure:
binary_name = "libbitsandbytes_cpu.so"
elif has_cublaslt:
binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so"
else:
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
binary_name = get_binary_name()
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"
return binary_name, cudart_path, cuda, cc, cuda_version_string
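The block above reduces the binary choice to three cases: setup failure, cublasLt-capable GPU, or an older GPU. A hedged standalone sketch of just that naming decision (the real code also needs cudart_path and the CUDA driver handle):

# Standalone sketch of the binary-name decision above (illustrative only):
def pick_binary_name(failure: bool, has_cublaslt: bool, cuda_version_string: str) -> str:
    if failure:
        return 'libbitsandbytes_cpu.so'                                    # no usable CUDA setup
    if has_cublaslt:
        return f'libbitsandbytes_cuda{cuda_version_string}.so'             # CC >= 7.5
    return f'libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so'      # older GPUs

assert pick_binary_name(True, True, '117') == 'libbitsandbytes_cpu.so'
assert pick_binary_name(False, True, '117') == 'libbitsandbytes_cuda117.so'
assert pick_binary_name(False, False, '102') == 'libbitsandbytes_cuda102_nocublaslt.so'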

View File

@ -1,6 +1,7 @@
import errno
from pathlib import Path
from typing import Set, Union
from bitsandbytes.cextension import CUDASetup
from .env_vars import get_potentially_lib_path_containing_env_vars

View File

@ -1,26 +0,0 @@
import typer
cli = typer.Typer()
@cli.callback()
def callback():
"""
Awesome Portal Gun
"""
@cli.command()
def shoot():
"""
Shoot the portal gun
"""
typer.echo("Shooting portal gun")
@cli.command()
def load():
"""
Load the portal gun
"""
typer.echo("Loading portal gun")

View File

@ -3,15 +3,19 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import itertools
import operator
import random
import torch
import itertools
import math
from functools import reduce # Required in Python 3
from typing import Tuple
from torch import Tensor
from .cextension import COMPILED_WITH_CUDA, lib
from functools import reduce # Required in Python 3
# math.prod not compatible with python < 3.8
def prod(iterable):
@ -82,7 +86,7 @@ if COMPILED_WITH_CUDA:
)
class CUBLAS_Context(object):
class CUBLAS_Context:
_instance = None
def __init__(self):
@ -112,7 +116,7 @@ class CUBLAS_Context(object):
return self.context[device.index]
class Cusparse_Context(object):
class Cusparse_Context:
_instance = None
def __init__(self):
@ -129,14 +133,73 @@ class Cusparse_Context(object):
return cls._instance
def create_linear_map(signed=True):
if signed:
return torch.linspace(-1.0, 1.0, 256)
def create_linear_map(signed=True, total_bits=8, add_zero=True):
sign = (-1.0 if signed else 0.0)
total_values = 2**total_bits
if add_zero or total_bits < 8:
# add a zero
# since we simulate fewer bits by having zeros in the data type, we
# need to center the quantization around zero and as such lose
# a single value
total_values = (2**total_bits if not signed else 2**total_bits-1)
values = torch.linspace(sign, 1.0, total_values)
gap = 256 - values.numel()
if gap == 0:
return values
else:
return torch.linspace(0.0, 1.0, 256)
l = values.numel()//2
#return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
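create_linear_map now pads the linearly spaced values with zeros around the middle so the returned code book always has 256 entries, even when fewer bits are simulated. A usage sketch, assuming the `from bitsandbytes import functional as F` import style used in the tests below:

# Usage sketch for create_linear_map as defined above (assumes a working install):
from bitsandbytes import functional as F

code = F.create_linear_map(signed=True, total_bits=4)
assert code.numel() == 256                # always padded to a full 8-bit code book
assert (code == 0).sum() >= 256 - 2**4    # the gap is filled with zeros around the middle
assert code.min() == -1.0 and code.max() == 1.0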
def create_dynamic_map(signed=True, n=7):
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
e = exponent_bits
p = precision_bits
has_sign = 1 if signed else 0
assert e+p == total_bits-has_sign
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
pvalues = []
for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
evalues.append(2**val)
values = []
lst = list(itertools.product([0, 1], repeat=precision_bits))
#for ev in evalues:
bias = 2**(exponent_bits-1)-1
for evalue in range(2**(exponent_bits)):
for bit_pattern in lst:
value = (1 if evalue != 0 else 0)
for i, pval in enumerate(list(bit_pattern)):
value += pval*(2**-(i+1))
if evalue == 0:
# subnormals
value = value*2**-(bias-1)
else:
# normals
value = value*2**-(evalue-bias-2)
values.append(value)
if signed:
values.append(-value)
assert len(values) == 2**total_bits
values.sort()
if total_bits < 8:
gap = 256 - len(values)
for i in range(gap):
values.append(0)
values.sort()
code = torch.Tensor(values)
code /= code.max()
return code
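create_fp8_map enumerates every sign/exponent/mantissa bit pattern, handles the subnormal case for a zero exponent, and normalizes the resulting code book to [-1, 1]. A usage sketch with the default 1-5-2 layout, under the same import assumption as above:

# Usage sketch for create_fp8_map as defined above:
from bitsandbytes import functional as F

# 1 sign bit + 5 exponent bits + 2 mantissa bits == 8 bits
code = F.create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
assert code.numel() == 256     # one entry per 8-bit pattern
assert code.max() == 1.0       # normalized so the largest magnitude is 1
assert code.min() == -1.0      # signed map is symmetric around zero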
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
"""
Creates the dynamic quantization map.
@ -157,40 +220,57 @@ def create_dynamic_map(signed=True, n=7):
# these are additional items that come from the case
# where all the exponent bits are zero and no
# indicator bit is present
additional_items = 2 ** (7 - n) - 1
non_sign_bits = total_bits - (1 if signed else 0)
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
if not signed:
additional_items = 2 * additional_items
for i in range(n):
fraction_items = (
2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
)
for i in range(max_exponent_bits):
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist()
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if additional_items > 0:
boundaries = torch.linspace(0.1, 1, additional_items + 1)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist()
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
data.append(0)
data.append(1.0)
gap = 256 - len(data)
for i in range(gap):
data.append(0)
data.sort()
return Tensor(data)
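create_dynamic_map now exposes total_bits and max_exponent_bits and pads the result with zeros up to 256 entries. A usage sketch with the defaults, under the same import assumption:

# Usage sketch for create_dynamic_map as defined above:
from bitsandbytes import functional as F

code = F.create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8)
assert code.numel() == 256     # 254 signed values plus the explicit 0 and 1.0 entries
assert code.max() == 1.0       # 1.0 is appended explicitly as the largest value
assert code[0] < 0             # sorted ascending, most negative value first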
def create_quantile_map(A, total_bits=8):
q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
q = q.tolist()
q.append(0)
gap = 256 - len(q)
for i in range(gap):
q.append(0)
q.sort()
q = Tensor(q)
q = q/q.abs().max()
return q
def get_special_format_str():
if not torch.cuda.is_available(): return 'col_turing'
major, minor = torch.cuda.get_device_capability()
major, _minor = torch.cuda.get_device_capability()
if major <= 7:
return "col_turing"
elif major == 8:
if major == 8:
return "col_ampere"
else:
return "col_turing"
@ -318,16 +398,12 @@ def nvidia_transform(
dim2 = ct.c_int32(shape[2])
ptr = CUBLAS_Context.get_instance().get_context(A.device)
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
return out, new_state
def estimate_quantiles(
A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.
@ -347,25 +423,37 @@ def estimate_quantiles(
out : torch.Tensor
Tensor with the 256 estimated quantiles.
offset : float
The offset for the first and last quantile from 0 and 1. Default: 1/512
The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
num_quantiles : int
The number of equally spaced quantiles.
Returns
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
'''
if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
if num_quantiles < 256 and offset == 1/(512):
# override default arguments
offset = 1/(2*num_quantiles)
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
is_on_gpu([A, out])
device = pre_call(A.device)
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
)
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16(
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
)
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else:
raise NotImplementedError(f"Not supported data type {A.dtype}")
post_call(device)
if num_quantiles < 256:
step = round(256/num_quantiles)
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
out = out[idx]
return out
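estimate_quantiles now accepts num_quantiles and subsamples the 256 estimated quantiles down to the requested count, adjusting the default offset accordingly. A usage sketch, assuming a CUDA device since the function dispatches to the fp16/fp32 CUDA kernels:

# Usage sketch for estimate_quantiles as documented above (assumes a CUDA device):
import torch
from bitsandbytes import functional as F

A = torch.randn(4096, device='cuda')
q256 = F.estimate_quantiles(A)                    # 256 equidistant quantiles, offset 1/512
q16 = F.estimate_quantiles(A, num_quantiles=16)   # coarser map, offset defaults to 1/32
assert q256.numel() == 256 and q16.numel() == 16
assert q16[0] <= q16[-1]                          # quantile estimates are ordered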
@ -398,15 +486,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
The quantization state to undo the quantization.
"""
if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
if absmax is None:
n = A.numel()
blocksize = (blocksize if A.device.type == 'cpu' else 4096)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
@ -415,8 +502,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu':
is_on_gpu([code, A, absmax, out, rand])
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
if rand is not None:
is_on_gpu([code, A, out, absmax, rand])
assert blocksize==4096
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:
@ -424,20 +516,19 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
elif A.dtype == torch.float16:
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
elif A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
# cpu
code = code.cpu()
assert rand is None
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
@ -482,27 +573,30 @@ def dequantize_blockwise(
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
if quant_state is None:
quant_state = (absmax, code)
else:
absmax, code = quant_state
if A.device.type != 'cpu':
if blocksize not in [2048, 4096]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
device = pre_call(A.device)
code = code.to(A.device)
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
code = code.cpu()
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
return out
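With this commit the blockwise quantization round trip accepts block sizes from 4096 down to 64 on the GPU. A round-trip sketch mirroring the updated test at the bottom of this diff, assuming a CUDA device:

# Round-trip sketch for the blockwise API above, mirroring the updated test below:
import torch
from bitsandbytes import functional as F

A = torch.randn(1024, 1024, device='cuda')
for blocksize in (4096, 2048, 1024, 512, 256, 128, 64):
    C, state = F.quantize_blockwise(A, blocksize=blocksize)     # uint8 codes + (absmax, code)
    A2 = F.dequantize_blockwise(C, state, blocksize=blocksize)  # back to float32
    assert C.dtype == torch.uint8 and A2.shape == A.shape
    assert (A - A2).abs().mean().item() < 0.02                  # small quantization error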
@ -958,7 +1052,7 @@ def histogram_scatter_add_2d(
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
is_on_gpu([histogram, index1, index2d, source])
is_on_gpu([histogram, index1, index2, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
@ -1417,7 +1511,7 @@ def get_colrow_absmax(
return row_stats, col_stats, nnz_block_ptr
class COOSparseTensor(object):
class COOSparseTensor:
def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32
assert colidx.dtype == torch.int32
@ -1434,7 +1528,7 @@ class COOSparseTensor(object):
self.values = values
class CSRSparseTensor(object):
class CSRSparseTensor:
def __init__(self, rows, cols, nnz, rowptr, colidx, values):
assert rowptr.dtype == torch.int32
assert colidx.dtype == torch.int32
@ -1451,7 +1545,7 @@ class CSRSparseTensor(object):
self.values = values
class CSCSparseTensor(object):
class CSCSparseTensor:
def __init__(self, rows, cols, nnz, colptr, rowidx, values):
assert colptr.dtype == torch.int32
assert rowidx.dtype == torch.int32
@ -1615,8 +1709,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
is_on_gpu([A, out])
if to_order == 'col32':
if transpose:

View File

@ -2,24 +2,11 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import (
Any,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
from typing import Optional, TypeVar, Union, overload
import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter
import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
@ -39,7 +26,7 @@ class StableEmbedding(torch.nn.Embedding):
sparse: bool = False,
_weight: Optional[Tensor] = None,
) -> None:
super(StableEmbedding, self).__init__(
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
@ -96,7 +83,7 @@ class Embedding(torch.nn.Embedding):
sparse: bool = False,
_weight: Optional[Tensor] = None,
) -> None:
super(Embedding, self).__init__(
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
@ -225,7 +212,7 @@ class Linear8bitLt(nn.Linear):
threshold=0.0,
index=None,
):
super(Linear8bitLt, self).__init__(
super().__init__(
input_features, output_features, bias
)
self.state = bnb.MatmulLtState()

View File

@ -5,12 +5,11 @@
from bitsandbytes.cextension import COMPILED_WITH_CUDA
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .sgd import SGD, SGD8bit, SGD32bit

View File

@ -21,18 +21,18 @@ class Adagrad(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
super(Adagrad, self).__init__(
super().__init__(
"adagrad",
params,
lr,
@ -63,19 +63,19 @@ class Adagrad8bit(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
assert block_wise
super(Adagrad8bit, self).__init__(
super().__init__(
"adagrad",
params,
lr,
@ -106,18 +106,18 @@ class Adagrad32bit(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
super(Adagrad32bit, self).__init__(
super().__init__(
"adagrad",
params,
lr,

View File

@ -28,7 +28,7 @@ class Adam(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(Adam, self).__init__(
super().__init__(
"adam",
params,
lr,
@ -57,7 +57,7 @@ class Adam8bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(Adam8bit, self).__init__(
super().__init__(
"adam",
params,
lr,
@ -86,7 +86,7 @@ class Adam32bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(Adam32bit, self).__init__(
super().__init__(
"adam",
params,
lr,
@ -146,7 +146,7 @@ class AnalysisAdam(torch.optim.Optimizer):
weight_decay=weight_decay,
amsgrad=amsgrad,
)
super(AnalysisAdam, self).__init__(params, defaults)
super().__init__(params, defaults)
self.analysis = bnb_analysis
self.savedir = savedir

View File

@ -20,7 +20,7 @@ class AdamW(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(AdamW, self).__init__(
super().__init__(
"adam",
params,
lr,
@ -49,7 +49,7 @@ class AdamW8bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(AdamW8bit, self).__init__(
super().__init__(
"adam",
params,
lr,
@ -78,7 +78,7 @@ class AdamW32bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(AdamW32bit, self).__init__(
super().__init__(
"adam",
params,
lr,

View File

@ -23,7 +23,7 @@ class LAMB(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
super(LAMB, self).__init__(
super().__init__(
"lamb",
params,
lr,
@ -56,7 +56,7 @@ class LAMB8bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
super(LAMB8bit, self).__init__(
super().__init__(
"lamb",
params,
lr,
@ -89,7 +89,7 @@ class LAMB32bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
super(LAMB32bit, self).__init__(
super().__init__(
"lamb",
params,
lr,

View File

@ -25,9 +25,9 @@ class LARS(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
"LARS without momentum is not supported!"
)
super(LARS, self).__init__(
super().__init__(
"lars",
params,
lr,
@ -59,9 +59,9 @@ class LARS8bit(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
"LARS without momentum is not supported!"
)
super(LARS8bit, self).__init__(
super().__init__(
"lars",
params,
lr,
@ -93,9 +93,9 @@ class LARS32bit(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
"LARS without momentum is not supported!"
)
super(LARS32bit, self).__init__(
super().__init__(
"lars",
params,
lr,
@ -123,12 +123,12 @@ class PytorchLARS(Optimizer):
max_unorm=0.02,
):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(
@ -143,10 +143,10 @@ class PytorchLARS(Optimizer):
raise ValueError(
"Nesterov momentum requires a momentum and zero dampening"
)
super(PytorchLARS, self).__init__(params, defaults)
super().__init__(params, defaults)
def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state)
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
@ -181,7 +181,7 @@ class PytorchLARS(Optimizer):
state = self.state[p]
d_p = p.grad
if weight_decay != 0:
d_p = d_p.add(param, alpha=weight_decay)
d_p = d_p.add(p, alpha=weight_decay)
if momentum != 0:
buf = state.get("momentum_buffer", None)

View File

@ -12,13 +12,13 @@ import torch
import bitsandbytes.functional as F
class MockArgs(object):
class MockArgs:
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
class GlobalOptimManager(object):
class GlobalOptimManager:
_instance = None
def __init__(self):
@ -56,9 +56,9 @@ class GlobalOptimManager(object):
"""
Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overidden
The key-values of the optimizer config for the input parameters are overridden
These can be regular optimizer parameters like "betas" or "lr", or
8-bit specific paramters like "optim_bits", "percentile_clipping".
8-bit specific parameters like "optim_bits", "percentile_clipping".
Parameters
----------
@ -93,13 +93,12 @@ class GlobalOptimManager(object):
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
super().__init__(params, defaults)
self.initialized = False
self.name2qmap = {}
self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set(
[
self.non_castable_tensor_keys = {
"qmap1",
"qmap2",
"max1",
@ -112,8 +111,7 @@ class Optimizer8bit(torch.optim.Optimizer):
"absmax1",
"absmax2",
"unorm_vec",
]
)
}
if optim_bits == 8:
self.fill_qmap()
@ -123,7 +121,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
super().__setstate__(state)
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
@ -155,8 +153,8 @@ class Optimizer8bit(torch.optim.Optimizer):
id_map = {
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in groups)),
chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable(g["params"] for g in groups),
)
}
@ -284,11 +282,11 @@ class Optimizer8bit(torch.optim.Optimizer):
return config
def init_state(self, group, p, gindex, pindex):
raise NotImplementedError(f"init_state method needs to be overidden")
raise NotImplementedError("init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex):
raise NotImplementedError(
f"The update_step method needs to be overidden"
"The update_step method needs to be overridden"
)
@ -310,9 +308,9 @@ class Optimizer2State(Optimizer8bit):
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if isinstance(betas, str):
# format: '(beta1, beta2)'
betas = betas.replace("(", "").replace(")", "").strip().split(",")
@ -324,10 +322,10 @@ class Optimizer2State(Optimizer8bit):
)
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
super().__init__(params, defaults, optim_bits)
if args is None:
args = {}
@ -542,9 +540,9 @@ class Optimizer1State(Optimizer8bit):
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(
@ -552,10 +550,10 @@ class Optimizer1State(Optimizer8bit):
)
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
super().__init__(params, defaults, optim_bits)
if args is None:
args = {}

View File

@ -23,11 +23,11 @@ class RMSprop(Optimizer1State):
):
if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop, self).__init__(
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,
@ -59,11 +59,11 @@ class RMSprop8bit(Optimizer1State):
):
if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop8bit, self).__init__(
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,
@ -96,11 +96,11 @@ class RMSprop32bit(Optimizer1State):
if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop32bit, self).__init__(
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,

View File

@ -21,8 +21,8 @@ class SGD(Optimizer1State):
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD, self).__init__(
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
"momentum",
params,
lr,
@ -52,8 +52,8 @@ class SGD8bit(Optimizer1State):
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD8bit, self).__init__(
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
"momentum",
params,
lr,
@ -83,8 +83,8 @@ class SGD32bit(Optimizer1State):
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD32bit, self).__init__(
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
"momentum",
params,
lr,

View File

@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
}
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
__launch_bounds__(TH, 4)
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
T vals[NUM];
float rand_vals[NUM];
unsigned char qvals[NUM];
T vals[NUM_PER_TH];
float rand_vals[NUM_PER_TH];
unsigned char qvals[NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;
@ -454,8 +454,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
__shared__ float smem_code[256];
__shared__ float smem_absmax_value[1];
if(threadIdx.x < 256)
smem_code[threadIdx.x] = code[threadIdx.x];
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
smem_code[i] = code[i];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
@ -510,15 +510,15 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
}
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n)
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
{
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
T vals[NUM];
unsigned char qvals[NUM];
T vals[NUM_PER_TH];
unsigned char qvals[NUM_PER_TH];
float local_abs_max = -FLT_MAX;
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;
__shared__ float smem_code[256];
//__shared__ float smem_code[256];
//float local_code[16];
if(threadIdx.x < 256)
smem_code[threadIdx.x] = code[threadIdx.x];
//if(threadIdx.x < 256)
//smem_code[threadIdx.x] = code[threadIdx.x];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
__syncthreads();
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
vals[j] = smem_code[qvals[j]]*local_abs_max;
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
__syncthreads();
StoreT(storet).Store(&(out[i]), vals, valid_items);
@ -2791,11 +2793,33 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 64, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 64, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 256, 128, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 256, 128, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 128, 64, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 128, 64, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 64, 64, 1>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 64, 64, 1>(float *code, unsigned char * A, float * absmax, float *out, const int n);

View File

@ -15,7 +15,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
@ -121,5 +121,3 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
#endif

View File

@ -50,11 +50,29 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
if(STOCHASTIC == 1)
assert(blocksize == 4096);
if(blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 2048)
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024)
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 256)
kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 64)
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@ -66,6 +84,17 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
else if(blocksize == 2048)
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
else if(blocksize == 1024)
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
else if(blocksize == 512)
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
else if(blocksize == 256)
kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
else if(blocksize == 128)
kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
else if(blocksize == 64)
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@ -659,10 +688,10 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);

View File

@ -128,7 +128,7 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n);
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,

View File

@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, n); }
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, n); }
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, n); }
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, n); }
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
@ -140,8 +140,8 @@ extern "C"
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
@ -290,4 +290,3 @@ extern "C"
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
}

View File

@ -76,6 +76,3 @@ if [[ -n "$CUDA_VERSION" ]]; then
else
echo ""
fi

View File

@ -18,7 +18,7 @@ def read(fname):
setup(
name=f"bitsandbytes",
version=f"0.35.3",
version=f"0.35.4",
author="Tim Dettmers",
author_email="dettmers@cs.washington.edu",
description="8-bit optimizers and matrix multiplication routines.",
@ -26,9 +26,6 @@ setup(
keywords="gpu optimizers optimization 8-bit quantization compression",
url="https://github.com/TimDettmers/bitsandbytes",
packages=find_packages(),
entry_points={
"console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
},
package_data={"": libs},
long_description=read("README.md"),
long_description_content_type="text/markdown",

View File

@ -1,4 +1,4 @@
from itertools import product, permutations
from itertools import permutations, product
import pytest
import torch
@ -27,7 +27,7 @@ str_values = list(
)
)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
*vals
)
for vals in str_values
@ -286,7 +286,7 @@ str_values = list(
has_bias
)
)
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values]
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize(

View File

@ -1,13 +1,13 @@
import os
import pytest
import bitsandbytes as bnb
from typing import List, NamedTuple
import pytest
import bitsandbytes as bnb
from bitsandbytes.cuda_setup import (
CUDA_RUNTIME_LIB,
evaluate_cuda_setup,
determine_cuda_runtime_lib_path,
evaluate_cuda_setup,
extract_candidate_paths,
)

View File

@ -6,12 +6,14 @@ from itertools import product
import einops
import pytest
import torch
import numpy as np
import bitsandbytes as bnb
from bitsandbytes import functional as F
from scipy.stats import norm
torch.set_printoptions(
precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20
@ -26,7 +28,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
class FFN(torch.nn.Module):
def __init__(self, input_features, hidden_size, bias=True):
super(FFN, self).__init__()
super().__init__()
self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
@ -40,7 +42,7 @@ class FFN(torch.nn.Module):
return x
class Timer(object):
class Timer:
def __init__(self):
self.starts = {}
self.ends = {}
@ -67,7 +69,7 @@ class Timer(object):
self.ends.pop(name)
if print_ms and name in self.agg:
print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")
return self.agg[name]
@ -149,30 +151,41 @@ def test_dynamic_quantization():
def test_dynamic_blockwise_quantization():
#print('')
for blocksize in [4096, 2048, 1024, 512]:
diffs = []
reldiffs = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.011
assert relerr < 0.018
#print('randn', blocksize, sum(diffs)/len(diffs))
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0033
diffs.append(diff)
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
# print(sum(diffs)/len(diffs))
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.0035
assert relerr < 0.015
#print('rand', blocksize, sum(diffs)/len(diffs))
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
def test_dynamic_blockwise_stochastic_quantization():
@ -289,7 +302,7 @@ batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
"dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
"dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
for vals in values_names
]
@ -347,7 +360,7 @@ seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
"hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
"hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
for vals in values
]
@ -412,7 +425,7 @@ hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
"seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
"seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
]
@ -444,7 +457,7 @@ batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
"seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
"seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
for vals in values
]
@ -529,7 +542,7 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
for vals in values
]
@ -567,7 +580,7 @@ dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
@ -596,7 +609,7 @@ transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values]
names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
@ -678,7 +691,7 @@ ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
for vals in values
]
@ -726,7 +739,7 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
for vals in values
]
@ -784,7 +797,7 @@ values = [
# values = list(product(batch, seq, model, hidden))
names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]
@ -952,7 +965,7 @@ dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values]
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
@ -1002,7 +1015,7 @@ dim2 = [1 * 1024]
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
@ -1058,7 +1071,7 @@ dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
@ -1105,7 +1118,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@ -1149,7 +1162,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@ -1224,7 +1237,7 @@ inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@ -1290,7 +1303,7 @@ values = list(
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
"dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
*vals
)
for vals in values
@ -1341,7 +1354,7 @@ a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
"dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
"dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
for vals in values
]
@ -1367,7 +1380,7 @@ dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim2 = [5]
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
@ -1404,7 +1417,7 @@ dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
@ -1485,7 +1498,7 @@ n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
@ -1550,7 +1563,7 @@ dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
"dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
"dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
]
@ -1616,17 +1629,6 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
# print(time.time() - t0)
def test_layout():
a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
a2, s2 = F.transform(a1, "col_turing")
print(a2.shape)
print(a1.flatten()[8 * 64 : 8 * 64 + 32])
for i in range(4):
print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
def test_coo2csr():
threshold = 1
A = torch.randn(128, 128).half().cuda()
@ -1678,7 +1680,7 @@ dim2 = [2048]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
@ -1794,7 +1796,7 @@ values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]
@ -2040,3 +2042,154 @@ def test_blockwise_cpu_large():
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
def test_fp8_quant():
for e_bits in range(1, 7):
p_bits = 7-e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
print(e_bits, p_bits)
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(3, sum(abserr)/len(abserr))
#print(3, sum(relerr)/len(relerr))
def test_few_bit_quant():
#print('')
for bits in range(2, 9):
#print('='*30, bits, '='*30)
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
abserrs = []
relerrs = []
code = None
if method == 'linear':
code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == 'fp8':
ebits = math.ceil(bits/2)
pbits = bits-ebits-1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == 'dynamic':
code = F.create_dynamic_map(True, bits-0, bits).cuda()
elif method == 'quantile':
values = torch.randn(2048, 2048, device='cuda')
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
#print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
values = torch.randn(1, 32, device='cuda')
values /= values.abs().max()
#values[values.abs() < 1e-6] += 1e-5
q1 = []
v1 = []
for v in values[0]:
idx = torch.abs(v-code).argmin()
q1.append(idx.item())
v1.append(code[idx].item())
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2-values)
abserrs.append(err2.mean().item())
relerrs.append((err2/(1e-10+values).abs()).mean().item())
if idx.sum():
# some weird cases
err1 = torch.abs(v1-values).mean()
#assert err2.mean() <= err1
else:
torch.testing.assert_allclose(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False
def test_kbit_quantile_estimation():
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
for bits in range(2, 9):
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1-val2).mean()
assert err < 0.038
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
for bits in range(2, 4):
total_values = 2**bits-1
p = np.linspace(0, 1, 2*total_values+1)
idx = np.arange(1, 2*total_values+1, 2)
p = p[idx]
offset = 1/(2*total_values)
p = np.linspace(offset, 1-offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
err = torch.abs(val1-val2).mean()
assert err < 0.035
def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half()
qa, SA = F.quantize_blockwise(a)
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
#print(max_theoretical_mu)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
F.dequantize_blockwise(qa, SA, blocksize=2048)
torch.cuda.synchronize()
#print((time.time()-t0)/1e6)
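
The new tests above (`test_fp8_quant`, `test_few_bit_quant`, `test_kbit_quantile_estimation`) build custom code maps (FP8-style, linear, dynamic, quantile) and feed them to the blockwise quantizer through its `code=` argument. A short sketch of that path, assuming a CUDA device and using the same calls the tests make:

```python
# Sketch of the code-map path the new tests exercise (assumes a CUDA device;
# the signatures below are the ones used in the tests above).
import math
import torch
import bitsandbytes.functional as F

bits = 8
ebits = math.ceil(bits / 2)   # exponent bits, as in test_few_bit_quant
pbits = bits - ebits - 1      # precision bits (one bit reserved for the sign)

# Build a signed FP8-style code map and use it for blockwise quantization.
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()

A = torch.randn(1024, 1024, device="cuda")
C, state = F.quantize_blockwise(A, code=code)
A_hat = F.dequantize_blockwise(C, state)

rel_err = ((A - A_hat).abs() / (A.abs() + 1e-8)).mean().item()
print(f"fp8 code map, mean relative error: {rel_err:.4f}")
```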

View File

@ -7,7 +7,7 @@ from torch import nn
import bitsandbytes as bnb
class MockArgs(object):
class MockArgs:
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
@ -15,7 +15,7 @@ class MockArgs(object):
class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super(MLP8bit, self).__init__()
super().__init__()
self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
@ -289,7 +289,7 @@ class LinearFunction(torch.autograd.Function):
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__()
super().__init__()
self.input_features = input_features
self.output_features = output_features
self.args = args
@ -312,7 +312,7 @@ class Linear8bit(nn.Module):
threshold = [0.0, 3.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]
names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
@ -378,7 +378,7 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]
names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)

View File

@ -18,7 +18,7 @@ k = 20
def get_temp_dir():
path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
path = f"/tmp/autoswap/{str(uuid.uuid4())}"
os.makedirs(path, exist_ok=True)
return path
@ -116,7 +116,7 @@ gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]
@ -187,7 +187,7 @@ dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
@ -250,7 +250,7 @@ optimizer_names = [
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]
@ -391,7 +391,7 @@ gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
"dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
for vals in values
]
@ -495,7 +495,7 @@ gtype = [torch.float32, torch.float16]
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]

View File

@ -1,149 +0,0 @@
./setup.py:20:10: F541 f-string is missing placeholders
./setup.py:21:13: F541 f-string is missing placeholders
./quicktest.py:5:1: F401 'bitsandbytes as bnb' imported but unused
./bitsandbytes/cuda_setup.py:42:56: F821 undefined name 'error_str'
./bitsandbytes/cuda_setup.py:43:15: F541 f-string is missing placeholders
./bitsandbytes/cuda_setup.py:67:5: F841 local variable 'context' is assigned to but never used
./bitsandbytes/cuda_setup.py:68:5: F841 local variable 'error_str' is assigned to but never used
./bitsandbytes/cuda_setup.py:76:9: F841 local variable 'result' is assigned to but never used
./bitsandbytes/cuda_setup.py:144:13: F841 local variable 'has_gpu' is assigned to but never used
./bitsandbytes/functional.py:294:13: F821 undefined name 'math'
./bitsandbytes/functional.py:295:16: F821 undefined name 'math'
./bitsandbytes/functional.py:303:5: F841 local variable 'ptrA' is assigned to but never used
./bitsandbytes/functional.py:304:5: F841 local variable 'ptrOut' is assigned to but never used
./bitsandbytes/functional.py:1057:17: W503 line break before binary operator
./bitsandbytes/functional.py:1058:17: W503 line break before binary operator
./bitsandbytes/functional.py:1059:17: W503 line break before binary operator
./bitsandbytes/functional.py:1649:1: F811 redefinition of unused 'get_special_format_str' from line 160
./bitsandbytes/functional.py:1687:5: F841 local variable 'ptrA' is assigned to but never used
./bitsandbytes/functional.py:1688:5: F841 local variable 'ptrOut' is assigned to but never used
./bitsandbytes/functional.py:1802:5: F841 local variable 'ccolsA' is assigned to but never used
./bitsandbytes/functional.py:1805:5: F841 local variable 'cldb' is assigned to but never used
./bitsandbytes/functional.py:1806:5: F841 local variable 'cldc' is assigned to but never used
./bitsandbytes/functional.py:1873:9: F841 local variable 'dtype' is assigned to but never used
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.MatmulLtState' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.bmm_cublas' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul_cublas' imported but unused
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.mm_cublas' imported but unused
./bitsandbytes/__init__.py:9:1: F401 '.nn.modules' imported but unused
./bitsandbytes/__init__.py:12:5: F401 '.optim.adam' imported but unused
./bitsandbytes/autograd/_functions.py:5:1: F401 'bitsandbytes as bnb' imported but unused
./bitsandbytes/autograd/_functions.py:12:75: W291 trailing whitespace
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Int8Params' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bit' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bitLt' imported but unused
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.StableEmbedding' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Any' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Callable' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Dict' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Iterator' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Mapping' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Set' imported but unused
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Tuple' imported but unused
./bitsandbytes/nn/modules.py:11:1: F401 'torch.nn.parameter.Parameter' imported but unused
./bitsandbytes/nn/modules.py:183:13: W503 line break before binary operator
./bitsandbytes/nn/modules.py:184:13: W503 line break before binary operator
./bitsandbytes/nn/modules.py:272:24: F821 undefined name 'dist'
./bitsandbytes/nn/modules.py:272:49: F821 undefined name 'dist'
./bitsandbytes/optim/optimizer.py:243:9: F841 local variable 'overflows' is assigned to but never used
./bitsandbytes/optim/optimizer.py:280:35: F541 f-string is missing placeholders
./bitsandbytes/optim/optimizer.py:283:35: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:27:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:59:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:91:39: F541 f-string is missing placeholders
./bitsandbytes/optim/lars.py:157:13: F841 local variable 'params_with_grad' is assigned to but never used
./bitsandbytes/optim/lars.py:158:13: F841 local variable 'd_p_list' is assigned to but never used
./bitsandbytes/optim/lars.py:159:13: F841 local variable 'momentum_buffer_list' is assigned to but never used
./bitsandbytes/optim/lars.py:174:35: F821 undefined name 'param'
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam' imported but unused
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam8bit' imported but unused
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam32bit' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW8bit' imported but unused
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW32bit' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD8bit' imported but unused
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD32bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS8bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS32bit' imported but unused
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.PytorchLARS' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB8bit' imported but unused
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB32bit' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop8bit' imported but unused
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop32bit' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad8bit' imported but unused
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad32bit' imported but unused
./bitsandbytes/optim/__init__.py:17:1: F401 '.optimizer.GlobalOptimManager' imported but unused
./bitsandbytes/optim/adam.py:229:21: F841 local variable 'max_exp_avg_sq' is assigned to but never used
./bitsandbytes/optim/rmsprop.py:25:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:27:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:59:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:61:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:94:39: F541 f-string is missing placeholders
./bitsandbytes/optim/rmsprop.py:96:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:24:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:55:39: F541 f-string is missing placeholders
./bitsandbytes/optim/sgd.py:86:39: F541 f-string is missing placeholders
./tests/test_optim.py:1:1: F401 'ctypes' imported but unused
./tests/test_optim.py:199:5: F841 local variable 'mask' is assigned to but never used
./tests/test_optim.py:218:9: F841 local variable 'atol' is assigned to but never used
./tests/test_optim.py:218:15: F841 local variable 'rtol' is assigned to but never used
./tests/test_optim.py:304:17: W503 line break before binary operator
./tests/test_optim.py:354:21: W503 line break before binary operator
./tests/test_autograd.py:309:13: F841 local variable 'err' is assigned to but never used
./tests/test_cuda_setup_evaluator.py:31:9: F821 undefined name 'test_dir'
./tests/test_cuda_setup_evaluator.py:33:14: F821 undefined name 'test_input'
./tests/test_cuda_setup_evaluator.py:81:32: E203 whitespace before ':'
./tests/test_functional.py:55:13: F841 local variable 'ms' is assigned to but never used
./tests/test_functional.py:177:5: F841 local variable 'diffs' is assigned to but never used
./tests/test_functional.py:178:5: F841 local variable 'reldiffs' is assigned to but never used
./tests/test_functional.py:260:5: F841 local variable 'minA' is assigned to but never used
./tests/test_functional.py:261:5: F841 local variable 'maxA' is assigned to but never used
./tests/test_functional.py:584:5: F841 local variable 'func' is assigned to but never used
./tests/test_functional.py:617:17: F841 local variable 'offset' is assigned to but never used
./tests/test_functional.py:618:17: F841 local variable 'col2' is assigned to but never used
./tests/test_functional.py:619:17: F841 local variable 'row2' is assigned to but never used
./tests/test_functional.py:705:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:706:9: F841 local variable 'C2' is assigned to but never used
./tests/test_functional.py:715:9: F841 local variable 'output' is assigned to but never used
./tests/test_functional.py:750:5: F841 local variable 'formatB' is assigned to but never used
./tests/test_functional.py:754:5: F841 local variable 'w2' is assigned to but never used
./tests/test_functional.py:763:5: F841 local variable 'dtype' is assigned to but never used
./tests/test_functional.py:770:9: F841 local variable 'out1' is assigned to but never used
./tests/test_functional.py:1108:5: F841 local variable 'relerr1' is assigned to but never used
./tests/test_functional.py:1108:14: F841 local variable 'relerr2' is assigned to but never used
./tests/test_functional.py:1114:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1135:9: F841 local variable 'C4' is assigned to but never used
./tests/test_functional.py:1179:5: F841 local variable 'err1' is assigned to but never used
./tests/test_functional.py:1179:11: F841 local variable 'err2' is assigned to but never used
./tests/test_functional.py:1179:17: F841 local variable 'err3' is assigned to but never used
./tests/test_functional.py:1180:5: F841 local variable 'relerr1' is assigned to but never used
./tests/test_functional.py:1180:14: F841 local variable 'relerr2' is assigned to but never used
./tests/test_functional.py:1192:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1313:9: F841 local variable 'c' is assigned to but never used
./tests/test_functional.py:1314:9: F841 local variable 'c2' is assigned to but never used
./tests/test_functional.py:1406:9: F841 local variable 'C1' is assigned to but never used
./tests/test_functional.py:1425:9: F841 local variable 'out2' is assigned to but never used
./tests/test_functional.py:1542:5: F841 local variable 'idx_col' is assigned to but never used
./tests/test_functional.py:1566:30: E203 whitespace before ':'
./tests/test_functional.py:1568:38: E203 whitespace before ':'
./tests/test_functional.py:1655:5: F841 local variable 'offset' is assigned to but never used
./tests/test_functional.py:1706:9: F841 local variable 'out' is assigned to but never used
./tests/test_functional.py:1822:9: F841 local variable 'out' is assigned to but never used
./tests/test_functional.py:1882:5: F841 local variable 'out2' is assigned to but never used
./tests/test_functional.py:1928:9: F841 local variable 'dtype' is assigned to but never used
./tests/test_functional.py:1934:9: F841 local variable 'minx' is assigned to but never used
./tests/test_functional.py:1948:5: F841 local variable 'C0' is assigned to but never used
./tests/test_modules.py:1:1: F401 'itertools.product' imported but unused
./tests/test_modules.py:52:9: F841 local variable 'norm' is assigned to but never used
./tests/test_modules.py:52:16: F821 undefined name 'math'
./tests/test_modules.py:52:26: F821 undefined name 'math'
./tests/test_modules.py:52:37: F821 undefined name 'math'
./tests/test_modules.py:177:21: F821 undefined name 'einops'
./tests/test_modules.py:233:9: F841 local variable 'stochastic' is assigned to but never used
./tests/test_modules.py:382:9: F841 local variable 'o1' is assigned to but never used