Merge branch 'main' into fp8_merge

This commit is contained in:
Tim Dettmers 2023-04-12 11:44:39 -07:00
commit 7140c01405
29 changed files with 899 additions and 260 deletions

View File

@ -189,3 +189,35 @@ Improvements:
- StableEmbedding layer now has device and dtype parameters to make it 1:1 replaceable with regular Embedding layers (@lostmsu)
- runtime performance of block-wise quantization slightly improved
- added error message for the case multiple libcudart.so are installed and bitsandbytes picks the wrong one
### 0.37.0
#### Int8 Matmul + backward support for all GPUs
Features:
- Int8 MatmulLt now supports backward through inversion of the ColTuring/ColAmpere format. Slow, but memory efficient. Big thanks to @borzunov
- Int8 now supported on all GPUs. On devices with compute capability < 7.5, the Int weights are cast to 16/32-bit for the matrix multiplication. Contributed by @borzunov
Improvements:
- Improved logging for the CUDA detection mechanism.
### 0.38.0
#### 8-bit Lion, Load/Store 8-bit Models directly from/to HF Hub
Features:
- Support for 32 and 8-bit Lion has been added. Thank you @lucidrains
- Support for serialization of Linear8bitLt layers (LLM.int8()). This allows to store and load 8-bit weights directly from the HuggingFace Hub. Thank you @myrab
- New bug report features `python -m bitsandbytes` now gives extensive debugging details to debug CUDA setup failures.
Bug fixes:
- Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins
- Fixed a bug where cudart.so libraries could not be found in newer PyTorch releases.
Improvements:
- Improved the CUDA Setup procedure by doing a more extensive search for CUDA libraries
Deprecated:
- Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
- Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0

View File

@ -60,8 +60,8 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env

View File

@ -11,10 +11,41 @@ Resources:
## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. LLM.int8() requires Turing or Ampere GPUs.
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
(Deprecated: CUDA 10.0 is deprecated and only CUDA >= 11.0) will be supported with release 0.39.0)
**Installation**:
``pip install bitsandbytes``
In some cases it can happen that you need to compile from source. If this happens please consider submitting a bug report with `python -m bitsandbytes` information. What now follows is some short instructions which might work out of the box if `nvcc` is installed. If these do not work see further below.
Compilation quickstart:
```bash
git clone https://github.com/timdettmers/bitsandbytes.git
cd bitsandbytes
# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 120}
# make argument in {cuda110, cuda11x, cuda12x}
# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
CUDA_VERSION=117 make cuda11x
python setup.py install
```
**Using Int8 inference with HuggingFace Transformers**
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
'decapoda-research/llama-7b-hf,
device_map='auto',
load_in_8bit=True,
max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB')
```
A more detailed example, can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py).
**Using 8-bit optimizer**:
1. Comment out optimizer: ``#torch.optim.Adam(....)``
2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same)
@ -39,7 +70,7 @@ out = linear(x.to(torch.float16))
## Features
- 8-bit Matrix multiplication with mixed precision decomposition
- LLM.int8() inference
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
- Stable Embedding Layer: Improved stability through better initialization, and normalization
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
- Fast quantile estimation: Up to 100x faster than other algorithms
@ -58,6 +89,10 @@ The bitsandbytes library is currently only supported on Linux distributions. Win
The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website.
To install run:
``pip install bitsandbytes``
## Using bitsandbytes
### Using Int8 Matrix Multiplication
@ -108,8 +143,23 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.m
2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_)
## Compile from source
To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands.
To compile from source, please follow the [compile_from_source.md](compile_from_source.md) instructions.
```bash
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
# For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc
bash cuda install 118 ~/local 1
```
To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`, for example the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the cuda version at `~/local/cuda-11.7`:
``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x``
For more detailed instruction, please follow the [compile_from_source.md](compile_from_source.md) instructions.
## License

View File

@ -1,11 +1,82 @@
import os
import sys
import shlex
import subprocess
from warnings import warn
from typing import Tuple
from os.path import isdir
import torch
HEADER_WIDTH = 60
def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple):
return tuple(
to_decode.decode("UTF-8").strip()
for to_decode in subprocess_err_out_tuple
)
def execute_and_return_decoded_std_streams(command_string):
return _decode(
subprocess.Popen(
shlex.split(command_string),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
).communicate()
)
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err
def find_file_recursive(folder, filename):
cmd = f'find {folder} -name {filename}'
out, err = execute_and_return(cmd)
if len(err) > 0:
raise RuntimeError('Something when wrong when trying to find file. Maybe you do not have a linux system?')
return out
def generate_bug_report_information():
print_header("")
print_header("BUG REPORT INFORMATION")
print_header("")
print('')
if 'CONDA_PREFIX' in os.environ:
paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so')
print_header("ANACONDA CUDA PATHS")
print(paths)
print('')
if isdir('/usr/local/'):
paths = find_file_recursive('/usr/local', '*cuda*so')
print_header("/usr/local CUDA PATHS")
print(paths)
print('')
if isdir(os.getcwd()):
paths = find_file_recursive(os.getcwd(), '*cuda*so')
print_header("WORKING DIRECTORY CUDA PATHS")
print(paths)
print('')
print_header("LD_LIBRARY CUDA PATHS")
lib_path = os.environ['LD_LIBRARY_PATH'].strip()
for path in set(lib_path.split(':')):
try:
if isdir(path):
print_header(f"{path} CUDA PATHS")
paths = find_file_recursive(path, '*cuda*so')
print(paths)
except:
print(f'Could not read LD_LIBRARY_PATH: {path}')
print('')
def print_header(
txt: str, width: int = HEADER_WIDTH, filler: str = "+"
@ -21,28 +92,16 @@ def print_debug_info() -> None:
)
print_header("")
print_header("DEBUG INFORMATION")
print_header("")
print()
generate_bug_report_information()
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.env_vars import to_be_ignored
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
for k, v in os.environ.items():
if "/" in v and not to_be_ignored(k, v):
print(f"'{k}': '{v}'")
print_header("")
print(
"\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n"
)
print_header("OTHER")
print(f"{COMPILED_WITH_CUDA = }")
print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
cuda = get_cuda_lib_handle()
print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}")
print_header("")
@ -55,6 +114,7 @@ Running a quick check that:
+ CUDA function is callable
"""
)
print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n")
try:
from bitsandbytes.optim import Adam
@ -91,3 +151,4 @@ except Exception as e:
print(e)
print_debug_info()
sys.exit(1)

View File

@ -0,0 +1 @@
from ._functions import undo_layout, get_inverse_transform_indices

View File

@ -2,6 +2,7 @@ import operator
import warnings
from dataclasses import dataclass
from functools import reduce # Required in Python 3
from typing import Tuple, Optional
import torch
@ -14,6 +15,12 @@ def prod(iterable):
tensor = torch.Tensor
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
@ -48,6 +55,51 @@ class GlobalOutlierPooler:
return torch.Tensor(list(self.outliers)).to(torch.int64)
def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
"""
Compute a permutation of indices that invert the specified (tiled) matrix transformation
:param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
:param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
:note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
:example: transform_tile function for the turing layout (bitsandbytes.functional as F)
:returns: indices
"""
d1, d2 = tile_size
assert 0 < d1 * d2 < 2**64
tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
# encode each position in tile as a tuple of <= 8 unique bytes
permuted_tile_indices = torch.zeros_like(tile_indices)
for i in range(8):
# select i-th byte, apply transformation and trace where each index ended up
ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
permuted_tile_i = transform_tile(sample_tile_i)
ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
permuted_tile_indices += ith_permuted_indices * (256**i)
if d1 * d2 < 256**i:
break # if all indices fit in i bytes, stop early
return permuted_tile_indices
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
"""
Undo a tiled permutation such as turing or ampere layout
:param permuted_tensor: torch tensor in a permuted layout
:param tile_indices: reverse transformation indices, from get_inverse_transform_indices
:return: contiguous row-major tensor
"""
(rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda
outputs[tile_indices.flatten()] = tensor
outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
return outputs.reshape(rows, cols).contiguous()
class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
@ -169,8 +221,21 @@ bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
if any(model_name in device_name for model_name in nvidia16_models):
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
return True
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
force_no_igemmlt: bool = False
CB = None
CxB = None
SB = None
@ -202,11 +267,31 @@ class MatmulLtState:
self.SBt = None
self.CBt = None
def get_tile_size(self):
assert self.formatB in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {self.formatB}"
return (8, 32) if self.formatB == "col_turing" else (32, 32)
@property
def tile_indices(self):
if self._tile_indices is None:
device = self.CxB.device
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
with torch.no_grad():
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
return self._tile_indices
class MatMul8bitLt(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
@ -214,9 +299,9 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B
@ -235,9 +320,7 @@ class MatMul8bitLt(torch.autograd.Function):
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
A.to(torch.float16), threshold=state.threshold
)
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
@ -248,12 +331,12 @@ class MatMul8bitLt(torch.autograd.Function):
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
if state.CxB is None:
if state.CxB is None and using_igemmlt:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
if not state.has_fp16_weights and state.CxB is None:
if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
@ -273,7 +356,10 @@ class MatMul8bitLt(torch.autograd.Function):
state.SCBt,
coo_tensorB,
) = F.double_quant(B.to(torch.float16))
state.CxB, state.SB = F.transform(CB, to_order=formatB)
if using_igemmlt:
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
state.CB = CB
else:
has_grad = False
@ -288,18 +374,17 @@ class MatMul8bitLt(torch.autograd.Function):
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
.to(A.dtype)
)
if state.CxB is not None:
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
else:
outliers = state.CB[:, state.idx.long()].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
shapeB = state.SB[0]
shapeB = state.SB[0] if state.SB else B.shape
if len(input_shape) == 3:
output_shape = (input_shape[0], input_shape[1], shapeB[0])
@ -307,16 +392,25 @@ class MatMul8bitLt(torch.autograd.Function):
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
# we apply the fused bias here
if using_igemmlt:
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)
if bias is None or bias.dtype == torch.float16:
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)
else:
A_wo_outliers = A.clone()
if state.idx is not None:
A_wo_outliers[:, state.idx.long()] = 0
output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
if bias is not None:
output = output.add_(bias)
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
@ -337,14 +431,13 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
return clone_func(output.view(output_shape))
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA = ctx.tensors
@ -359,9 +452,7 @@ class MatMul8bitLt(torch.autograd.Function):
# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(
-1, grad_output.shape[-1]
).contiguous()
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
@ -376,17 +467,22 @@ class MatMul8bitLt(torch.autograd.Function):
if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CxB is not None:
CB = (
undo_layout(state.CxB, state.tile_indices)
.to(ctx.dtype_A)
.mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
)
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
raise Exception('State must contain either CBt or CB matrix for backward')
raise Exception("State must contain either CBt or CB or CxB matrix for backward")
return grad_A, grad_B, None, grad_bias, None

View File

@ -11,8 +11,6 @@ from bitsandbytes.cuda_setup.main import CUDASetup
setup = CUDASetup.get_instance()
if setup.initialized != True:
setup.run_cuda_setup()
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
setup.print_log_stack()
lib = setup.lib
try:
@ -20,14 +18,22 @@ try:
CUDASetup.get_instance().generate_instructions()
CUDASetup.get_instance().print_log_stack()
raise RuntimeError('''
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
https://github.com/TimDettmers/bitsandbytes/issues''')
CUDA Setup failed despite GPU being available. Please run the following command to get more information:
python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
lib.cadam32bit_g32
lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.")
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False
# print the setup details after checking for errors so we do not print twice
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
setup.print_log_stack()

View File

@ -11,6 +11,7 @@ def to_be_ignored(env_var: str, value: str) -> bool:
"HOME", # Linux shell default
"TMUX", # Terminal Multiplexer
"XDG_DATA_DIRS", # XDG: Desktop environment stuff
"XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff
"XDG_RUNTIME_DIR",
"MAIL", # something related to emails
"SHELL", # binary for currently invoked shell

View File

@ -21,12 +21,21 @@ import os
import errno
import torch
from warnings import warn
from itertools import product
from pathlib import Path
from typing import Set, Union
from .env_vars import get_potentially_lib_path_containing_env_vars
CUDA_RUNTIME_LIB: str = "libcudart.so"
# these are the most common libs names
# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead
# we have libcudart.so.11.0 which causes a lot of errors before
# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0']
# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
backup_paths = []
backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0')
class CUDASetup:
_instance = None
@ -80,9 +89,10 @@ class CUDASetup:
self.add_log_entry('python setup.py install')
def initialize(self):
self.has_printed = False
self.lib = None
self.initialized = False
if not getattr(self, 'initialized', False):
self.has_printed = False
self.lib = None
self.initialized = False
def run_cuda_setup(self):
self.initialized = True
@ -97,13 +107,15 @@ class CUDASetup:
package_dir = Path(__file__).parent.parent
binary_path = package_dir / binary_name
print('bin', binary_path)
try:
if not binary_path.exists():
self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
legacy_binary_name = "libbitsandbytes_cpu.so"
self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
binary_path = package_dir / legacy_binary_name
if not binary_path.exists():
if not binary_path.exists() or torch.cuda.is_available():
self.add_log_entry('')
self.add_log_entry('='*48 + 'ERROR' + '='*37)
self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
@ -112,10 +124,10 @@ class CUDASetup:
self.add_log_entry('3. You have multiple conflicting CUDA libraries')
self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.')
self.add_log_entry('='*80)
self.add_log_entry('')
self.generate_instructions()
self.print_log_stack()
raise Exception('CUDA SETUP: Setup Failed!')
self.lib = ct.cdll.LoadLibrary(binary_path)
else:
@ -123,7 +135,6 @@ class CUDASetup:
self.lib = ct.cdll.LoadLibrary(binary_path)
except Exception as ex:
self.add_log_entry(str(ex))
self.print_log_stack()
def add_log_entry(self, msg, is_warning=False):
self.cuda_setup_log.append((msg, is_warning))
@ -148,7 +159,7 @@ def is_cublasLt_compatible(cc):
if cc is not None:
cc_major, cc_minor = cc.split('.')
if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Proceeding to load CPU-only library...", is_warning=True)
CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
else:
has_cublaslt = True
return has_cublaslt
@ -176,11 +187,12 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
return {
path / CUDA_RUNTIME_LIB
for path in candidate_paths
if (path / CUDA_RUNTIME_LIB).is_file()
}
paths = set()
for libname in CUDA_RUNTIME_LIBS:
for path in candidate_paths:
if (path / libname).is_file():
paths.add(path / libname)
return paths
def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
@ -200,12 +212,12 @@ def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
if len(results_paths) > 1:
warning_msg = (
f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. "
"We'll flip a coin and try one of these, in order to fail forward.\n"
"Either way, this might cause trouble in the future:\n"
"If you get `CUDA error: invalid device function` errors, the above "
"might be the cause and the solution is to make sure only one "
f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.")
f"{CUDA_RUNTIME_LIBS} in the paths that we search based on your env.")
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
@ -233,7 +245,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
return next(iter(conda_cuda_libs))
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
if "LD_LIBRARY_PATH" in candidate_env_vars:
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
@ -243,7 +255,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
warn_in_case_of_duplicates(lib_ld_cuda_libs)
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
remaining_candidate_env_vars = {
env_var: value for env_var, value in candidate_env_vars.items()
@ -255,7 +267,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
cuda_runtime_libs.update(find_cuda_lib_in(value))
if len(cuda_runtime_libs) == 0:
CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...')
cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
warn_in_case_of_duplicates(cuda_runtime_libs)
@ -361,10 +373,10 @@ def evaluate_cuda_setup():
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
print('')
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
print('='*80)
if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None
if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None
cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path()

View File

@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA:
lib.crmsprop32bit_g32,
lib.crmsprop32bit_g16,
)
str2optimizer32bit["lion"] = (
lib.clion32bit_g32,
lib.clion32bit_g16,
)
str2optimizer32bit["adagrad"] = (
lib.cadagrad32bit_g32,
lib.cadagrad32bit_g16,
@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA:
lib.crmsprop_static_8bit_g32,
lib.crmsprop_static_8bit_g16,
)
str2optimizer8bit["lion"] = (
lib.clion_static_8bit_g32,
lib.clion_static_8bit_g16,
)
str2optimizer8bit["lamb"] = (
lib.cadam_static_8bit_g32,
lib.cadam_static_8bit_g16,
@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA:
lib.crmsprop_8bit_blockwise_fp32,
lib.crmsprop_8bit_blockwise_fp16,
)
str2optimizer8bit_blockwise["lion"] = (
lib.clion_8bit_blockwise_fp32,
lib.clion_8bit_blockwise_fp16,
)
str2optimizer8bit_blockwise["adagrad"] = (
lib.cadagrad_8bit_blockwise_fp32,
lib.cadagrad_8bit_blockwise_fp16,
@ -655,9 +667,11 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
torch.Tensor:
Quantized 8-bit tensor.
'''
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
return out
@ -682,9 +696,11 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
torch.Tensor:
32-bit output tensor.
'''
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
return out
@ -753,6 +769,8 @@ def optimizer_update_32bit(
f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
)
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, unorm_vec])
if g.dtype == torch.float32 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][0](
get_ptr(g),
@ -795,6 +813,7 @@ def optimizer_update_32bit(
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
post_call(prev_device)
def optimizer_update_8bit(
@ -873,6 +892,8 @@ def optimizer_update_8bit(
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit[optimizer_name][0](
get_ptr(p),
@ -925,6 +946,7 @@ def optimizer_update_8bit(
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
post_call(prev_device)
def optimizer_update_8bit_blockwise(
@ -947,6 +969,8 @@ def optimizer_update_8bit_blockwise(
skip_zeros=False,
) -> None:
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit_blockwise[optimizer_name][0](
get_ptr(p),
@ -991,6 +1015,7 @@ def optimizer_update_8bit_blockwise(
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
post_call(prev_device)
def percentile_clipping(
@ -1006,6 +1031,7 @@ def percentile_clipping(
The current optimiation steps (number of past gradient norms).
"""
prev_device = pre_call(grad.device)
is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(
@ -1023,6 +1049,7 @@ def percentile_clipping(
)
else:
raise ValueError(f"Gradient type {grad.dtype} not supported!")
post_call(prev_device)
current_gnorm = torch.sqrt(gnorm_vec[step % 100])
vals, idx = torch.sort(gnorm_vec)
@ -1779,6 +1806,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
)
nnz = cooA.nnz
prev_device = pre_call(B.device)
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
@ -1855,6 +1883,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
ccolsB,
)
# else: assertion error
post_call(prev_device)
return out

View File

@ -9,6 +9,8 @@ import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
@ -238,19 +240,10 @@ class Int8Params(torch.nn.Parameter):
class Linear8bitLt(nn.Linear):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
super().__init__(
input_features, output_features, bias
)
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None):
super().__init__(input_features, output_features, bias)
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
@ -260,9 +253,54 @@ class Linear8bitLt(nn.Linear):
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
)
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
# reorder weight layout back from ampere/turing to row
reorder_layout = True
weight_clone = self.weight.data.clone()
else:
reorder_layout = False
try:
if reorder_layout:
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
super()._save_to_state_dict(destination, prefix, keep_vars)
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
weight_name = "SCB"
# case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, weight_name)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, weight_name)
key_name = prefix + f"{weight_name}"
if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
elif not self.state.has_fp16_weights and param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
finally:
if reorder_layout:
self.weight.data = weight_clone
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
for key in unexpected_keys:
input_name = key[len(prefix):]
if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't call them directly without
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")
input_param = state_dict[key]
self.weight.SCB.copy_(input_param)
unexpected_keys.remove(key)
def init_8bit_state(self):
self.state.CB = self.weight.CB
@ -270,30 +308,23 @@ class Linear8bitLt(nn.Linear):
self.weight.CB = None
self.weight.SCB = None
def forward(self, x):
def forward(self, x: torch.Tensor):
self.state.is_training = self.training
if self.weight.CB is not None:
self.init_8bit_state()
# weights are cast automatically as Int8Params, but the bias has to be cast manually
# if self.bias is not None and self.bias.dtype != torch.float16:
# self.bias.data = self.bias.data.half()
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
#out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights:
if not self.state.memory_efficient_backward and self.state.CB is not None:
if self.state.CB is not None and self.state.CxB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
elif self.state.memory_efficient_backward and self.state.CxB is not None:
# For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
# Thus, we delete CxB from the state.
del self.state.CxB
return out
@ -336,22 +367,4 @@ class SwitchBackLinearBnb(nn.Linear):
if self.weight.CB is not None:
self.init_8bit_state()
# weights are cast automatically as Int8Params, but the bias has to be cast manually
# if self.bias is not None and self.bias.dtype != torch.float16:
# self.bias.data = self.bias.data.half()
#out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
if not self.state.has_fp16_weights:
if not self.state.memory_efficient_backward and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
elif self.state.memory_efficient_backward and self.state.CxB is not None:
# For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
# Thus, we delete CxB from the state.
del self.state.CxB
return out

View File

@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .lion import Lion, Lion8bit, Lion32bit
from .sgd import SGD, SGD8bit, SGD32bit

View File

@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Lion8bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Lion32bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)

View File

@ -665,7 +665,7 @@ class Optimizer1State(Optimizer8bit):
step,
config["lr"],
None,
0.0,
config['betas'][1],
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,

View File

@ -1,20 +1,35 @@
# Compiling from source
Basic steps.
1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cpuonly`
2. `CUDA_VERSION=XXX python setup.py install`
1. `CUDA_VERSION=XXX make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly`
2. `python setup.py install`
To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive).
For your convenience, there is an installation script in the root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. To do this you can add this to your `.bashrc` by executing these commands:
You can install CUDA locally without sudo by following the following steps:
```bash
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64/" >> ~/.bashrc
echo "export PATH=$PATH:/usr/local/cuda/bin/" >> ~/.bashrc
source ~/.bashrc
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
bash cuda install 117 ~/local 1
```
By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler.
Either `nvcc` needs to be in path for the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed
If you type `nvcc` and it cannot be found, you might need to add to your path or set the CUDA_HOME variable. You can run `python -m bitsandbytes` to find the path to CUDA. For example if `python -m bitsandbytes` shows you the following:
```
++++++++++++++++++ /usr/local CUDA PATHS +++++++++++++++++++
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so
```
You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be able to compile like this.
``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x``
If you have problems compiling the library with these instructions from source, please open an issue.

View File

@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
return __int_as_float(old);
}
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template <typename T>
__device__ int sgn(T val) {
return (T(0) < val) - (val < T(0));
}
template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
@ -745,7 +753,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{
@ -792,6 +800,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
@ -823,7 +834,7 @@ template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
@ -892,6 +903,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
@ -1160,7 +1175,7 @@ __global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@ -1221,6 +1236,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
if(unorm != NULL)
local_unorm += s1_vals[j]*s1_vals[j];
break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
@ -1244,9 +1262,10 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(1024, 1)
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@ -1309,8 +1328,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(weight_decay > 0.0f)
g_val += ((float)p_vals[j])*weight_decay;
if(weight_decay > 0.0f) {
switch(OPTIMIZER) {
case MOMENTUM:
case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
break;
}
}
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
switch(OPTIMIZER)
@ -1323,6 +1353,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
@ -1651,10 +1685,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
if(weight_decay > 0.0f)
g_val += ((float)p_vals[j])*weight_decay;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
if(weight_decay > 0.0f) {
switch(OPTIMIZER) {
case MOMENTUM:
case ADAGRAD:
case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
break;
}
}
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
@ -1666,6 +1710,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
else
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break;
case LION:
// here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
@ -1703,6 +1752,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
@ -2694,24 +2746,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \
const float beta1, const float eps, const float weight_decay, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(LION, half)
MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(LION, half)
MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)
@ -2733,6 +2789,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
float *unorm, \
const float beta1, \
const float beta2, \
const float eps, const int step, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
@ -2744,11 +2801,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
MAKE_PreconditionStatic8bit1State(LION, half)
MAKE_PreconditionStatic8bit1State(LION, float)
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \
const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
@ -2760,6 +2820,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
MAKE_optimizerStatic8bit1State(RMSPROP, half)
MAKE_optimizerStatic8bit1State(RMSPROP, float)
MAKE_optimizerStatic8bit1State(LION, half)
MAKE_optimizerStatic8bit1State(LION, float)
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
@ -2863,5 +2925,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)

View File

@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,

View File

@ -118,17 +118,28 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case LION:
// in lion, the momentum update after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
break;
}
}
@ -162,12 +173,22 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
default:
break;
}
@ -196,6 +217,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
case LION:
num_blocks = n/BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
@ -705,6 +727,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
@ -724,6 +748,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
@ -736,6 +762,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, LION);
MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);

View File

@ -70,6 +70,7 @@ typedef enum Optimizer_t
RMSPROP = 2,
LARS = 3,
ADAGRAD = 4,
LION = 5,
} Optimizer_t;
typedef enum Transform_t

View File

@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
MAKE_FUNC32(adam, ADAM, half, 16)
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
MAKE_FUNC32(lion, LION, float, 32)
MAKE_FUNC32(lion, LION, half, 16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
MAKE_FUNC8(momentum, MOMENTUM, half, 16)
MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_BLOCKWISE8(lion, LION, half, 16)
MAKE_BLOCKWISE8(lion, LION, float, 32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
@ -161,6 +167,8 @@ extern "C"
MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, 32)
MAKE_CFUNC32(lion, half, 16)
MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16)
@ -183,6 +191,8 @@ extern "C"
MAKE_CFUNC8(momentum, half, 16)
MAKE_CFUNC8(rmsprop, float, 32)
MAKE_CFUNC8(rmsprop, half, 16)
MAKE_CFUNC8(lion, float, 32)
MAKE_CFUNC8(lion, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
@ -196,6 +206,8 @@ extern "C"
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_CBLOCKWISE8(lion, LION, half, 16)
MAKE_CBLOCKWISE8(lion, LION, float, 32)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)

View File

@ -12,10 +12,12 @@ URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installer
URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
URL120=https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run
URL121=https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run
CUDA_VERSION=$1
BASE_PATH=$2
EXPORT_BASHRC=$3
if [[ -n "$CUDA_VERSION" ]]; then
if [[ "$CUDA_VERSION" -eq "92" ]]; then
@ -60,11 +62,14 @@ if [[ -n "$CUDA_VERSION" ]]; then
elif [[ "$CUDA_VERSION" -eq "120" ]]; then
URL=$URL120
FOLDER=cuda-12.0
elif [[ "$CUDA_VERSION" -eq "121" ]]; then
URL=$URL121
FOLDER=cuda-12.1
else
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
echo "argument error: No cuda version passed as input. Choose among versions 92 to 121"
fi
else
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
echo "argument error: No cuda version passed as input. Choose among versions 92 to 112"
fi
FILE=$(basename $URL)
@ -72,11 +77,13 @@ FILE=$(basename $URL)
if [[ -n "$CUDA_VERSION" ]]; then
echo $URL
echo $FILE
wget $URL
#wget $URL
bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
source ~/.bashrc
if [ "$EXPORT_BASHRC" -eq "1" ]; then
echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc
echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc
source ~/.bashrc
fi
else
echo ""
fi

View File

@ -10,8 +10,8 @@ if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then
fi
module unload cuda
module unload gcc
module unload cuda && echo "no module function available. Probably not on a slurm cluster."
module unload gcc && echo "no module function available. Probably not on a slurm cluster."
rm -rf dist build
make cleaneggs
@ -128,6 +128,16 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-12.1
make cuda12x CUDA_VERSION=121
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-10.2
@ -241,5 +251,15 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then
exit 64
fi
make clean
export CUDA_HOME=$BASE_PATH/cuda-12.1
make cuda12x_nomatmul CUDA_VERSION=121
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then
# Control will enter here if $DIRECTORY doesn't exist.
echo "Compilation unsuccessul!" 1>&2
exit 64
fi
python -m build
python -m twine upload dist/* --verbose

View File

@ -1,6 +1,6 @@
# No kernel image available
This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. So solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``?
This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``?
If you are feeling lucky, you can also try to compile the library from source. This can be still problematic if your PATH variables have multiple cuda versions. As such, it is recommended to figure out path conflicts before you proceed with compilation.

View File

@ -1 +1,2 @@
lion-pytorch
pytest

View File

@ -18,7 +18,7 @@ def read(fname):
setup(
name=f"bitsandbytes",
version=f"0.36.0-2",
version=f"0.38.0.post2",
author="Tim Dettmers",
author_email="dettmers@cs.washington.edu",
description="8-bit optimizers and matrix multiplication routines.",

View File

@ -5,95 +5,20 @@ import pytest
import bitsandbytes as bnb
from bitsandbytes.cuda_setup.main import (
CUDA_RUNTIME_LIB,
determine_cuda_runtime_lib_path,
evaluate_cuda_setup,
extract_candidate_paths,
)
"""
'LD_LIBRARY_PATH': ':/mnt/D/titus/local/cuda-11.1/lib64/'
'CONDA_EXE': '/mnt/D/titus/miniconda/bin/conda'
'LESSCLOSE': '/usr/bin/lesspipe %s %s'
'OLDPWD': '/mnt/D/titus/src'
'CONDA_PREFIX': '/mnt/D/titus/miniconda/envs/8-bit'
'SSH_AUTH_SOCK': '/mnt/D/titus/.ssh/ssh-agent.tim-uw.sock'
'CONDA_PREFIX_1': '/mnt/D/titus/miniconda'
'PWD': '/mnt/D/titus/src/8-bit'
'HOME': '/mnt/D/titus'
'CONDA_PYTHON_EXE': '/mnt/D/titus/miniconda/bin/python'
'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
'TMUX': '/tmp/tmux-1007/default,59286,1'
'XDG_DATA_DIRS': '/usr/local/share:/usr/share:/var/lib/snapd/desktop'
'SSH_TTY': '/dev/pts/0'
'MAIL': '/var/mail/titus'
'SHELL': '/bin/bash'
'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/1007/bus'
'XDG_RUNTIME_DIR': '/run/user/1007'
'PATH': '/mnt/D/titus/miniconda/envs/8-bit/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/mnt/D/titus/local/cuda-11.1/bin'
'LESSOPEN': '| /usr/bin/lesspipe %s'
'_': '/mnt/D/titus/miniconda/envs/8-bit/bin/python'
# any that include 'CONDA' that are not 'CONDA_PREFIX'
# we search for
'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
"""
class InputAndExpectedOutput(NamedTuple):
input: str
output: str
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
(
f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
(
f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
]
@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmpdir, request):
for path in extract_candidate_paths(request.param):
test_dir.mkdir()
if CUDA_RUNTIME_LIB in path:
(test_input / CUDA_RUNTIME_LIB).touch()
UNHAPPY_PATH__LD_LIB_TEST_PATHS = [
f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}",
f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}",
]
def test_full_system():
def test_cuda_full_system():
## this only tests the cuda version and not compute capability
# if CONDA_PREFIX exists, it has priority before all other env variables
# but it does not contain the library directly, so we need to look at the a sub-folder
version = ""
if "CONDA_PREFIX" in os.environ:
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so.11.0')
major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split("."))
version = float(f"{major}.{minor}")

143
tests/test_linear8bitlt.py Normal file
View File

@ -0,0 +1,143 @@
import os
from contextlib import nullcontext
from itertools import product
from tempfile import TemporaryDirectory
import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
@pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
tile_indices = get_inverse_transform_indices(transform, tile_size)
cxb = transform(x)
torch.cuda.synchronize()
restored_x = undo_layout(cxb, tile_indices)
torch.cuda.synchronize()
assert restored_x.is_contiguous()
assert torch.all(torch.eq(restored_x, x))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt():
linear = torch.nn.Linear(1024, 3072)
x = torch.randn(3, 1024, dtype=torch.half)
linear_custom = Linear8bitLt(
linear.in_features,
linear.out_features,
linear.bias is not None,
has_fp16_weights=False,
threshold=6.0,
)
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
).to(linear.weight.dtype)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
linear = linear.half().cuda()
x_ref = x.clone().cuda().requires_grad_(True)
x_ours = x.clone().cuda().requires_grad_(True)
fx_ref = linear(x_ref).float()
grad_proj = torch.randn_like(fx_ref)
(fx_ref * grad_proj).mean().backward()
fx_ours = linear_custom(x_ours).float()
(fx_ours * grad_proj).mean().backward()
assert torch.allclose(fx_ref, fx_ours, atol=0.02)
assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
assert not linear_custom.state.has_fp16_weights
assert linear_custom.state.CB is not None
assert linear_custom.state.CxB is None
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
list(product([False, True], [False, True], [False, True], [False, True])))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half)
linear_custom = Linear8bitLt(
linear.in_features,
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
)
if force_no_igemmlt:
linear_custom.state.force_no_igemmlt = True
linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
if serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()
x_first = x.clone().cuda().requires_grad_(True)
fx_first = linear_custom(x_first).float()
grad_proj = torch.randn_like(fx_first)
(fx_first * grad_proj).mean().backward()
if not serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()
with TemporaryDirectory() as tmpdir:
state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
state_path = os.path.join(tmpdir, "state.pth")
torch.save(linear.state_dict(), state_path)
torch.save(state_dict_8bit, state_path_8bit)
if not has_fp16_weights:
assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)
new_state_dict = torch.load(state_path_8bit)
new_linear_custom = Linear8bitLt(
linear.in_features,
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
)
if force_no_igemmlt:
new_linear_custom.state.force_no_igemmlt = True
if deserialize_before_cuda:
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
new_linear_custom.load_state_dict(new_state_dict, strict=True)
new_linear_custom = new_linear_custom.cuda()
if not deserialize_before_cuda:
new_linear_custom.load_state_dict(new_state_dict, strict=True)
x_second = x.clone().cuda().requires_grad_(True)
fx_second = new_linear_custom(x_second).float()
(fx_second * grad_proj).mean().backward()
# if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
if has_fp16_weights or not deserialize_before_cuda:
assert torch.allclose(fx_first, fx_second, atol=1e-5)
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)

View File

@ -382,7 +382,7 @@ names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
@pytest.mark.parametrize("memory_efficient_backward", [True, False])
@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (
bnb.nn.Linear8bitLt(

View File

@ -7,6 +7,8 @@ from itertools import product
from os.path import join
import pytest
from lion_pytorch import Lion
import torch
import bitsandbytes as bnb
@ -16,6 +18,13 @@ import bitsandbytes.functional as F
k = 20
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx = torch.isclose(a, b, rtol, atol)
error_count = (idx == 0).sum().item()
if error_count > max_error_count:
print(f"Too many values not close: assert {error_count} < {max_error_count}")
torch.testing.assert_allclose(a, b, rtol, atol)
def get_temp_dir():
path = f"/tmp/autoswap/{str(uuid.uuid4())}"
@ -31,6 +40,7 @@ str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
None,
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@ -38,6 +48,7 @@ str2optimizers["momentum_pytorch"] = (
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@ -54,6 +65,10 @@ str2optimizers["adam8bit"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["lion8bit"] = (
Lion,
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@ -71,6 +86,10 @@ str2optimizers["adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["lion8bit_blockwise"] = (
Lion,
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@ -82,6 +101,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
@ -90,6 +110,9 @@ str2statenames["adam8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lion8bit"] = [
("exp_avg", "state1", "qmap1", "max1")
]
str2statenames["lamb8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
@ -98,6 +121,9 @@ str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["lion8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1")
]
str2statenames["momentum8bit"] = [
("momentum_buffer", "state1", "qmap1", "max1")
]
@ -113,7 +139,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
@ -144,6 +170,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
@ -152,7 +179,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
rtol=rtol,
)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
if i % (k // 5) == 0 and i > 0:
path = get_temp_dir()
@ -162,14 +191,15 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
atol=atol, rtol=rtol,
max_error_count=10)
if gtype == torch.float16:
# the adam buffers should also be close because they are 32-bit
@ -241,9 +271,11 @@ dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = [
"adam8bit",
"lion8bit",
"momentum8bit",
"rmsprop8bit",
"adam8bit_blockwise",
"lion8bit_blockwise",
"lars8bit",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",
@ -285,7 +317,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch_optimizer.step()
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
@ -313,7 +347,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states.append(s1.clone())
err = torch.abs(p1 - p2)
relerr = err / torch.abs(p1)
relerr = err / (torch.abs(p1)+1e-9)
assert err.mean() < 0.0001
assert relerr.mean() < 0.001
@ -367,9 +401,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
== 0
)
assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(
p1, p2.float(), atol=patol, rtol=prtol
)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error