Merge branch 'main' into remove_unused_code

commit aca55881b9
Author: Tim Dettmers
Date:   2022-09-05 16:29:25 -07:00 (committed by GitHub)
8 changed files with 55 additions and 64 deletions

View File

@@ -23,12 +23,12 @@ Resources:
 1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)``
 2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same)
 3. There are two modes:
-   - Mixed 8-bit training with 16-bit main weights. Pass the argument ``use_fp16_weights=True`` (default)
-   - Int8 inference. Pass the argument ``use_fp16_weights=False``
+   - Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default)
+   - Int8 inference. Pass the argument ``has_fp16_weights=False``
 4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``.
 ```python
 # LLM.int8()
-linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, use_fp16_weights=False, threshold=6.0)
+linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0)
 # inputs need to be fp16
 out = linear(x.to(torch.float16))
 ```
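For context, a minimal end-to-end sketch of the int8 inference usage described in the hunk above; the layer sizes and input tensor are made up for illustration, the ``Linear8bitLt`` arguments come straight from the README:

```python
import torch
import bitsandbytes as bnb

# Hypothetical layer sizes, chosen only for the example.
dim1, dim2 = 1024, 4096

# LLM.int8(): int8 weights plus the mixed-precision outlier threshold.
linear = bnb.nn.Linear8bitLt(
    dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0
).cuda()

# Inputs need to be fp16 and on the GPU.
x = torch.randn(8, dim1, device="cuda", dtype=torch.float16)
out = linear(x)
print(out.shape)  # (8, 4096)
```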
@@ -115,7 +115,8 @@ We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fa
 ## How to cite us
 If you found this library and found LLM.int8() useful, please consider citing our work:
-```
+```bibtex
 @article{dettmers2022llmint8,
   title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale},
   author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke},
@@ -124,8 +125,9 @@ If you found this library and found LLM.int8() useful, please consider citing ou
 }
 ```
-For 8-bit optimizers or quantization routines please consider citing the following work.
-```
+For 8-bit optimizers or quantization routines, please consider citing the following work:
+```bibtex
 @article{dettmers2022optimizers,
   title={8-bit Optimizers via Block-wise Quantization},
   author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke},

View File

@@ -26,7 +26,7 @@ def check_cuda_result(cuda, result_val):
     if result_val != 0:
         error_str = ctypes.c_char_p()
         cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
-        raise Exception(f"CUDA exception! Error code: {error_str.value.decode()}")
+        print(f"CUDA exception! Error code: {error_str.value.decode()}")

 def get_cuda_version(cuda, cudart_path):
     # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
@@ -55,7 +55,7 @@ def get_cuda_lib_handle():
         cuda = ctypes.CDLL("libcuda.so")
     except OSError:
         # TODO: shouldn't we error or at least warn here?
-        raise Exception('CUDA SETUP: ERROR! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
+        print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
         return None

     check_cuda_result(cuda, cuda.cuInit(0))
@@ -116,6 +116,10 @@ def evaluate_cuda_setup():
     print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
     print('='*80)
     binary_name = "libbitsandbytes_cpu.so"
+    #if not torch.cuda.is_available():
+        #print('No GPU detected. Loading CPU library...')
+        #return binary_name

     cudart_path = determine_cuda_runtime_lib_path()
     if cudart_path is None:
         print(
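The pattern in this file, probing the CUDA driver through ``ctypes`` and warning instead of raising so CPU-only machines can still import the package, can be sketched standalone. This is a rough illustration of that idea; the function name below is made up and is not the library's API:

```python
import ctypes

def try_load_cuda_driver():
    # Attempt to load the CUDA driver library; return None instead of raising
    # so that machines without a GPU can continue with the CPU-only path.
    try:
        cuda = ctypes.CDLL("libcuda.so")
    except OSError:
        print("WARNING: libcuda.so not found; falling back to CPU-only setup.")
        return None

    # cuInit(0) initializes the driver API; a nonzero return code is an error.
    result = cuda.cuInit(0)
    if result != 0:
        error_str = ctypes.c_char_p()
        cuda.cuGetErrorString(result, ctypes.byref(error_str))
        print(f"CUDA exception! Error code: {error_str.value.decode()}")
    return cuda
```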

View File

@@ -184,14 +184,9 @@ def create_dynamic_map(signed=True, n=7):

 def get_special_format_str():
+    if not torch.cuda.is_available(): return 'col_turing'
     major, minor = torch.cuda.get_device_capability()
-    if major < 7:
-        print(
-            f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
-        )
-    assert major >= 7
-    if major == 7:
+    if major <= 7:
         return "col_turing"
     elif major == 8:
         return "col_ampere"
@@ -1667,21 +1662,6 @@ def double_quant(
     return out_row, out_col, row_stats, col_stats, coo_tensor

-def get_special_format_str():
-    major, minor = torch.cuda.get_device_capability()
-    if major < 7:
-        print(
-            f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
-        )
-    assert major >= 7
-    if major == 7: return 'col_turing'
-    elif major == 8: return 'col_ampere'
-    else: return 'col_turing'

 def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
     prev_device = pre_call(A.device)
     if state is None: state = (A.shape, from_order)
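For reference, the capability-to-format selection introduced above can be exercised on its own. This is a minimal sketch of the same mapping, not the library function itself; the function name is illustrative:

```python
import torch

def pick_8bit_tile_format():
    # Mirrors the updated get_special_format_str(): default to the Turing
    # layout when no GPU is visible or the device is pre-Ampere, and use the
    # Ampere layout on compute capability 8.x.
    if not torch.cuda.is_available():
        return "col_turing"
    major, _ = torch.cuda.get_device_capability()
    return "col_ampere" if major == 8 else "col_turing"

print(pick_8bit_tile_format())
```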

View File

@@ -5,13 +5,12 @@
 from bitsandbytes.cextension import COMPILED_WITH_CUDA

-if COMPILED_WITH_CUDA:
-    from .adam import Adam, Adam8bit, Adam32bit
-    from .adamw import AdamW, AdamW8bit, AdamW32bit
-    from .sgd import SGD, SGD8bit, SGD32bit
-    from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
-    from .lamb import LAMB, LAMB8bit, LAMB32bit
-    from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
-    from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
+from .adam import Adam, Adam8bit, Adam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit
+from .sgd import SGD, SGD8bit, SGD32bit
+from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
+from .lamb import LAMB, LAMB8bit, LAMB32bit
+from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit

 from .optimizer import GlobalOptimManager
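For orientation, the optimizers re-exported here are drop-in replacements for their ``torch.optim`` counterparts. A minimal usage sketch; the model, learning rate, and betas below are illustrative, the ``Adam8bit`` call itself follows the project README:

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()

# Drop-in replacement for torch.optim.Adam that keeps optimizer state in 8 bits.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.995))

out = model(torch.randn(4, 512, device="cuda"))
out.sum().backward()
optimizer.step()
optimizer.zero_grad()
```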

View File

@@ -371,7 +371,11 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
 template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
 {
 #ifdef NO_CUBLASLT
-    printf("ERROR: Your GPU does not support Int8 Matmul!");
+    cout << "" << endl;
+    cout << "=============================================" << endl;
+    cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
+    cout << "=============================================" << endl;
+    cout << "" << endl;
     assert(false);
     return 0;

View File

@@ -18,7 +18,7 @@ def read(fname):

 setup(
     name=f"bitsandbytes",
-    version=f"0.32.1",
+    version=f"0.32.2",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",

View File

@@ -40,6 +40,7 @@ names = [
     ids=names,
 )
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+    if not torch.cuda.is_available(): pytest.skip('No GPU found.')
     if dim2 > 0:
         dim2 = dim2 - (dim2 % 16)
         dim3 = dim3 - (dim3 % 16)
@@ -306,6 +307,7 @@ def test_matmullt(
     has_fp16_weights,
     has_bias
 ):
+    if not torch.cuda.is_available(): pytest.skip('No GPU found.')
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
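The GPU-skip guard added to these tests can also be written once as a reusable marker. A small sketch of that equivalent pytest idiom, not what this diff does; the test body is invented for illustration:

```python
import pytest
import torch

# Evaluated at collection time; skips the whole test when no GPU is present.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="No GPU found."
)

@requires_cuda
def test_matmul_shapes():
    a = torch.randn(8, 16, device="cuda")
    b = torch.randn(16, 32, device="cuda")
    assert (a @ b).shape == (8, 32)
```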

View File

@@ -1813,16 +1813,16 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):

 batch_size = 1
-seqdim = 2048
+seqdim = 1
 values = []
-values.append((batch_size, seqdim, 768, 4 * 768))
+#values.append((batch_size, seqdim, 768, 4 * 768))
 # values.append((batch_size, seqdim, 1024, 4*1024))
 # values.append((batch_size, seqdim, 1536, 4*1536))
 # values.append((batch_size, seqdim, 2048, 4*2048))
 # values.append((batch_size, seqdim, 2560, 4*2560))
 # values.append((batch_size, seqdim, 4096, 4*4096))
 # values.append((batch_size, seqdim, 5140, 4*5140))
-# values.append((batch_size, seqdim, 12288, 4*12288))
+values.append((batch_size, seqdim, 12288, 4*12288))
 names = [
     "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
 ]
@@ -1830,6 +1830,7 @@ names = [
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
+    iters = 128
     formatB = F.get_special_format_str()

     A = torch.randn(batch, seq, model, device="cuda").half()
@@ -1848,28 +1849,33 @@ def test_bench_matmul(batch, seq, model, hidden):
     linearMixedBit.eval()

     # warmup
-    for i in range(100):
+    for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
     print("")

     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
     print(
-        f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
     )

     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         bnb.matmul(A, B)
     torch.cuda.synchronize()
-    print(
-        f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul(A, B, threshold=6.0)
+    torch.cuda.synchronize()
+    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
     C32A, SA = F.transform(CA, "col32")
@@ -1877,18 +1883,16 @@ def test_bench_matmul(batch, seq, model, hidden):
     CxB, SB = F.transform(CB, to_order=formatB)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
     torch.cuda.synchronize()
-    print(
-        f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     BA, statsB = F.vectorwise_quant(B, dim=1)
     CxB, SB = F.nvidia_transform(CB, to_order=formatB)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         A2 = A.view(-1, A.shape[-1]).contiguous()
         CA, statsA = F.vectorwise_quant(A2, dim=1)
         C32A, SA = F.nvidia_transform(CA, "col32")
@@ -1896,15 +1900,13 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
     torch.cuda.synchronize()
-    print(
-        f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
     CxB, SB = F.nvidia_transform(CB, to_order=formatB)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         A2 = A.view(-1, A.shape[-1]).contiguous()
         CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
         C32A, SA = F.nvidia_transform(CA, "col32")
@@ -1912,14 +1914,12 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         out = Cout * statsB * statsA * (1.0 / (127 * 127))
     torch.cuda.synchronize()
-    print(
-        f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linear8bit(A)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         linear8bit(A)
     torch.cuda.synchronize()
     print(
@@ -1929,7 +1929,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     linearMixedBit(A)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         linearMixedBit(A)
     torch.cuda.synchronize()
     print(
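The benchmark above repeats the same timing idiom throughout: warm up, synchronize, time a loop, synchronize again. A standalone sketch of that pattern; the helper name and the matmul sizes below are arbitrary and only for illustration:

```python
import time
import torch

def bench_cuda(fn, iters=128, warmup=16):
    # CUDA kernels launch asynchronously, so synchronize before and after the
    # timed loop; otherwise time.time() measures launch overhead, not the work.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

if torch.cuda.is_available():
    A = torch.randn(32, 1024, 1024, device="cuda").half()
    B = torch.randn(1024, 1024, device="cuda").half()
    print(f"fp16 matmul: {bench_cuda(lambda: torch.matmul(A, B.t())):.6f}s/iter")
```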