Merge pull request #1 from TimDettmers/main

Update main branch
Dmitry Baranchuk 2022-09-10 19:33:21 -07:00 committed by GitHub
commit 843ad0631c
14 changed files with 60 additions and 197 deletions

View File

@@ -23,12 +23,12 @@ Resources:
1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)``
2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same)
3. There are two modes:
- Mixed 8-bit training with 16-bit main weights. Pass the argument ``use_fp16_weights=True`` (default)
- Int8 inference. Pass the argument ``use_fp16_weights=False``
- Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default)
- Int8 inference. Pass the argument ``has_fp16_weights=False``
4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``.
```python
# LLM.int8()
linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, use_fp16_weights=False, threshold=6.0)
linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0)
# inputs need to be fp16
out = linear(x.to(torch.float16))
```
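For context, here is a minimal end-to-end sketch of the usage described above (an illustrative aside, not part of the changeset; layer sizes and batch size are arbitrary, and a CUDA device is assumed):
```python
import torch
import bitsandbytes as bnb

dim1, dim2 = 1024, 4096  # illustrative sizes

# Int8 inference with the full LLM.int8() method
linear = bnb.nn.Linear8bitLt(
    dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0
).cuda()  # moving the layer to the GPU is where the weights become int8

x = torch.randn(8, dim1, device="cuda", dtype=torch.float16)  # inputs need to be fp16
out = linear(x)  # fp16 output of shape (8, dim2)
```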
@@ -115,7 +115,8 @@ We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fa
## How to cite us
If you found this library and found LLM.int8() useful, please consider citing our work:
```
```bibtex
@article{dettmers2022llmint8,
title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale},
author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke},
@@ -124,8 +125,9 @@ If you found this library and found LLM.int8() useful, please consider citing ou
}
```
For 8-bit optimizers or quantization routines please consider citing the following work.
```
For 8-bit optimizers or quantization routines, please consider citing the following work:
```bibtex
@article{dettmers2022optimizers,
title={8-bit Optimizers via Block-wise Quantization},
author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke},

View File

@@ -12,7 +12,7 @@ from .autograd._functions import (
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
from . import cuda_setup
from . import cuda_setup, utils
if COMPILED_WITH_CUDA:
from .optim import adam

View File

@@ -3,8 +3,9 @@
# cli()
import os
import sys
import torch
from warnings import warn
import torch
HEADER_WIDTH = 60
@@ -32,8 +33,6 @@ print()
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
from .cuda_setup.env_vars import to_be_ignored
from .utils import print_stderr
print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
for k, v in os.environ.items():
@@ -84,7 +83,7 @@ try:
except ImportError:
print()
print_stderr(
warn(
f"WARNING: {__package__} is currently running as CPU-only!\n"
"Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
f"If you think that this is so erroneously,\nplease report an issue!"

View File

@@ -1,6 +1,5 @@
import operator
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
from dataclasses import dataclass
@@ -368,9 +367,6 @@ class MatMul8bitLt(torch.autograd.Function):
return grad_A, grad_B, None, grad_bias, None
matmul = MatMul8bitLt.apply
def matmul(
A: tensor,
B: tensor,

View File

@@ -1,79 +0,0 @@
import ctypes
from dataclasses import dataclass, field
@dataclass
class CudaLibVals:
# code bits taken from
# https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
nGpus: ctypes.c_int = field(default=ctypes.c_int())
cc_major: ctypes.c_int = field(default=ctypes.c_int())
cc_minor: ctypes.c_int = field(default=ctypes.c_int())
device: ctypes.c_int = field(default=ctypes.c_int())
error_str: ctypes.c_char_p = field(default=ctypes.c_char_p())
cuda: ctypes.CDLL = field(init=False, repr=False)
ccs: List[str, ...] = field(init=False)
def _initialize_driver_API(self):
self.check_cuda_result(self.cuda.cuInit(0))
def _load_cuda_lib(self):
"""
1. find libcuda.so library (GPU driver) (/usr/lib)
init_device -> init variables -> call function by reference
"""
libnames = "libcuda.so"
for libname in libnames:
try:
self.cuda = ctypes.CDLL(libname)
except OSError:
continue
else:
break
else:
raise OSError("could not load any of: " + " ".join(libnames))
def call_cuda_func(self, function_obj, **kwargs):
CUDA_SUCCESS = 0 # constant taken from cuda.h
pass
# if (CUDA_SUCCESS := function_obj(
def _error_handle(cuda_lib_call_return_value):
"""
2. call extern C function to determine CC
(see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
"""
CUDA_SUCCESS = 0 # constant taken from cuda.h
if cuda_lib_call_return_value != CUDA_SUCCESS:
self.cuda.cuGetErrorString(
cuda_lib_call_return_value,
ctypes.byref(self.error_str),
)
print("Count not initialize CUDA - failure!")
raise Exception("CUDA exception!")
return cuda_lib_call_return_value
def __post_init__(self):
self._load_cuda_lib()
self._initialize_driver_API()
self.check_cuda_result(
self.cuda, self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus))
)
tmp_ccs = []
for gpu_index in range(self.nGpus.value):
check_cuda_result(
self.cuda,
self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index),
)
check_cuda_result(
self.cuda,
self.cuda.cuDeviceComputeCapability(
ctypes.byref(self.cc_major),
ctypes.byref(self.cc_minor),
self.device,
),
)
tmp_ccs.append(f"{self.cc_major.value}.{self.cc_minor.value}")
self.ccs = sorted(tmp_ccs, reverse=True)
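As an aside on the removed helper above (the same driver-API queries remain in cuda_setup/main.py's get_compute_capabilities): a minimal, self-contained sketch of querying compute capabilities with ctypes, assuming libcuda.so is available; the function names here are hypothetical, not the library's API:
```python
import ctypes

CUDA_SUCCESS = 0  # constant taken from cuda.h


def _check(cuda, result):
    # Decode the driver's error string and raise if the call failed.
    if result != CUDA_SUCCESS:
        error_str = ctypes.c_char_p()
        cuda.cuGetErrorString(result, ctypes.byref(error_str))
        raise OSError(f"CUDA driver error: {error_str.value.decode()}")


def query_compute_capabilities():
    cuda = ctypes.CDLL("libcuda.so")              # GPU driver library
    _check(cuda, cuda.cuInit(0))                  # initialize the driver API
    ngpus = ctypes.c_int()
    _check(cuda, cuda.cuDeviceGetCount(ctypes.byref(ngpus)))
    ccs = []
    for gpu_index in range(ngpus.value):
        device = ctypes.c_int()
        _check(cuda, cuda.cuDeviceGet(ctypes.byref(device), gpu_index))
        cc_major, cc_minor = ctypes.c_int(), ctypes.c_int()
        _check(cuda, cuda.cuDeviceComputeCapability(
            ctypes.byref(cc_major), ctypes.byref(cc_minor), device))
        ccs.append(f"{cc_major.value}.{cc_minor.value}")
    return sorted(ccs, reverse=True)
```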

View File

@@ -17,9 +17,7 @@ evaluation:
"""
import ctypes
from pathlib import Path
from ..utils import execute_and_return
from .paths import determine_cuda_runtime_lib_path
@@ -28,7 +26,7 @@ def check_cuda_result(cuda, result_val):
if result_val != 0:
error_str = ctypes.c_char_p()
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
raise Exception(f"CUDA exception! Error code: {error_str.value.decode()}")
print(f"CUDA exception! Error code: {error_str.value.decode()}")
def get_cuda_version(cuda, cudart_path):
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
@@ -57,7 +55,7 @@ def get_cuda_lib_handle():
cuda = ctypes.CDLL("libcuda.so")
except OSError:
# TODO: shouldn't we error or at least warn here?
raise Exception('CUDA SETUP: ERROR! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
return None
check_cuda_result(cuda, cuda.cuInit(0))
@@ -80,7 +78,6 @@ def get_compute_capabilities(cuda):
cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()
result = ctypes.c_int()
device = ctypes.c_int()
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
@@ -119,6 +116,10 @@ def evaluate_cuda_setup():
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('='*80)
binary_name = "libbitsandbytes_cpu.so"
#if not torch.cuda.is_available():
#print('No GPU detected. Loading CPU library...')
#return binary_name
cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
print(

View File

@@ -2,23 +2,11 @@ from pathlib import Path
from typing import Set, Union
from warnings import warn
from ..utils import print_stderr
from .env_vars import get_potentially_lib_path_containing_env_vars
CUDA_RUNTIME_LIB: str = "libcudart.so"
def purge_unwanted_semicolon(tentative_path: Path) -> Path:
"""
Special function to handle the following exception:
__LMOD_REF_COUNT_PATH=/sw/cuda/11.6.2/bin:2;/mmfs1/home/dettmers/git/sched/bin:1;/mmfs1/home/dettmers/data/anaconda3/bin:1;/mmfs1/home/dettmers/data/anaconda3/condabin:1;/mmfs1/home/dettmers/.local/bin:1;/mmfs1/home/dettmers/bin:1;/usr/local/bin:1;/usr/bin:1;/usr/local/sbin:1;/usr/sbin:1;/mmfs1/home/dettmers/.fzf/bin:1;/mmfs1/home/dettmers/data/local/cuda-11.4/bin:1
"""
# if ';' in str(tentative_path):
# path_as_str, _ = str(tentative_path).split(';')
pass
def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
@@ -29,7 +17,7 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
}
if non_existent_directories:
print_stderr(
warn(
"WARNING: The following directories listed in your path were found to "
f"be non-existent: {non_existent_directories}"
)
@@ -117,8 +105,6 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
}
cuda_runtime_libs = set()
for env_var, value in remaining_candidate_env_vars.items():
cuda_runtime_libs.update(find_cuda_lib_in(value))

View File

@@ -5,7 +5,6 @@
import ctypes as ct
import operator
import random
import math
import torch
from typing import Tuple
@@ -185,14 +184,9 @@ def create_dynamic_map(signed=True, n=7):
def get_special_format_str():
if not torch.cuda.is_available(): return 'col_turing'
major, minor = torch.cuda.get_device_capability()
if major < 7:
print(
f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
)
assert major >= 7
if major == 7:
if major <= 7:
return "col_turing"
elif major == 8:
return "col_ampere"
@@ -248,23 +242,6 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
return getattr(lib, name)
class GlobalData(object):
_instance = None
def __init__(self):
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.data = {}
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
cls._instance.initialize()
return cls._instance
def get_transform_buffer(
shape, dtype, device, to_order, from_order="row", transpose=False
):
@@ -1685,21 +1662,6 @@ def double_quant(
return out_row, out_col, row_stats, col_stats, coo_tensor
def get_special_format_str():
major, minor = torch.cuda.get_device_capability()
if major < 7:
print(
f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
)
assert major >= 7
if major == 7: return 'col_turing'
elif major == 8: return 'col_ampere'
else: return 'col_turing'
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device)
if state is None: state = (A.shape, from_order)

View File

@@ -5,13 +5,12 @@
from bitsandbytes.cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA:
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .optimizer import GlobalOptimManager
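For context, the optimizers re-exported here are drop-in replacements for their torch.optim counterparts; a minimal usage sketch (an illustrative aside; the model and hyperparameters are arbitrary, and a CUDA device is assumed):
```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# 8-bit Adam in place of torch.optim.Adam
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.995))

out = model(torch.randn(16, 1024, device="cuda"))
out.pow(2).mean().backward()  # any scalar loss
optimizer.step()
optimizer.zero_grad()
```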

View File

@@ -1,6 +1,5 @@
import shlex
import subprocess
import sys
from typing import Tuple
@@ -22,11 +21,3 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err
def print_stderr(s: str) -> None:
print(s, file=sys.stderr)
def warn_of_missing_prerequisite(s: str) -> None:
print_stderr("WARNING, missing pre-requisite: " + s)

View File

@@ -371,7 +371,11 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
printf("ERROR: Your GPU does not support Int8 Matmul!");
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);
return 0;

View File

@@ -18,7 +18,7 @@ def read(fname):
setup(
name=f"bitsandbytes",
version=f"0.32.1",
version=f"0.32.3",
author="Tim Dettmers",
author_email="dettmers@cs.washington.edu",
description="8-bit optimizers and matrix multiplication routines.",

View File

@@ -40,6 +40,7 @@ names = [
ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if not torch.cuda.is_available(): pytest.skip('No GPU found.')
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
@@ -306,6 +307,7 @@ def test_matmullt(
has_fp16_weights,
has_bias
):
if not torch.cuda.is_available(): pytest.skip('No GPU found.')
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")

View File

@@ -1813,16 +1813,16 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
batch_size = 1
seqdim = 2048
seqdim = 1
values = []
values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
values.append((batch_size, seqdim, 12288, 4*12288))
names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]
@@ -1830,6 +1830,7 @@ names = [
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
iters = 128
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
@@ -1848,28 +1849,33 @@ def test_bench_matmul(batch, seq, model, hidden):
linearMixedBit.eval()
# warmup
for i in range(100):
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print(
f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
for i in range(iters):
bnb.matmul(A, B)
torch.cuda.synchronize()
print(
f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B, threshold=6.0)
torch.cuda.synchronize()
print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
C32A, SA = F.transform(CA, "col32")
@@ -1877,18 +1883,16 @@ def test_bench_matmul(batch, seq, model, hidden):
CxB, SB = F.transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
for i in range(iters):
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
torch.cuda.synchronize()
print(
f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
BA, statsB = F.vectorwise_quant(B, dim=1)
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
for i in range(iters):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1)
C32A, SA = F.nvidia_transform(CA, "col32")
@@ -1896,15 +1900,13 @@ def test_bench_matmul(batch, seq, model, hidden):
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
torch.cuda.synchronize()
print(
f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
for i in range(iters):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
C32A, SA = F.nvidia_transform(CA, "col32")
@@ -1912,14 +1914,12 @@ def test_bench_matmul(batch, seq, model, hidden):
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
out = Cout * statsB * statsA * (1.0 / (127 * 127))
torch.cuda.synchronize()
print(
f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linear8bit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
for i in range(iters):
linear8bit(A)
torch.cuda.synchronize()
print(
@@ -1929,7 +1929,7 @@ test_bench_matmul(batch, seq, model, hidden):
linearMixedBit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
for i in range(iters):
linearMixedBit(A)
torch.cuda.synchronize()
print(