Merge branch 'cuda-bin-switch-and-cli' of github.com:TimDettmers/bitsandbytes into cuda-bin-switch-and-cli

This commit is contained in:
Tim Dettmers 2022-08-16 10:57:10 -07:00
commit 111b876449
6 changed files with 29 additions and 12 deletions

View File

@ -1,10 +1,15 @@
from dataclasses import dataclass import operator
import torch import torch
import math
import bitsandbytes as bnb import bitsandbytes as bnb
import bitsandbytes.functional as F import bitsandbytes.functional as F
from dataclasses import dataclass
from functools import reduce # Required in Python 3
# math.prod is only available on Python >= 3.8, so provide a portable fallback.
def prod(iterable):
    """Return the product of all elements in *iterable*.

    Stand-in for ``math.prod`` that folds the elements together with
    ``operator.mul``, starting from 1, so an empty iterable yields 1.
    """
    return reduce(operator.mul, iterable, 1)
tensor = torch.Tensor tensor = torch.Tensor
""" """
@ -12,8 +17,6 @@ tensor = torch.Tensor
This is particularly important for small models where outlier features This is particularly important for small models where outlier features
are less systematic and occur with low frequency. are less systematic and occur with low frequency.
""" """
class GlobalOutlierPooler(object): class GlobalOutlierPooler(object):
_instance = None _instance = None
@ -201,7 +204,7 @@ class MatMul8bitLt(torch.autograd.Function):
def forward(ctx, A, B, out=None, state=MatmulLtState()): def forward(ctx, A, B, out=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty # default to pytorch behavior if inputs are empty
ctx.is_empty = False ctx.is_empty = False
if math.prod(A.shape) == 0: if prod(A.shape) == 0:
ctx.is_empty = True ctx.is_empty = True
ctx.A = A ctx.A = A
ctx.B = B ctx.B = B

View File

@ -45,6 +45,9 @@ def get_cuda_version(cuda, cudart_path):
major = version//1000 major = version//1000
minor = (version-(major*1000))//10 minor = (version-(major*1000))//10
if major < 11:
print('CUDA SETUP: CUDA version lower than 11 are currenlty not supported!')
return f'{major}{minor}' return f'{major}{minor}'
@ -110,6 +113,10 @@ def get_compute_capability(cuda):
def evaluate_cuda_setup(): def evaluate_cuda_setup():
print('')
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please use this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('='*80)
binary_name = "libbitsandbytes_cpu.so" binary_name = "libbitsandbytes_cpu.so"
cudart_path = determine_cuda_runtime_lib_path() cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None: if cudart_path is None:
@ -121,6 +128,7 @@ def evaluate_cuda_setup():
print(f"CUDA SETUP: CUDA path found: {cudart_path}") print(f"CUDA SETUP: CUDA path found: {cudart_path}")
cuda = get_cuda_lib_handle() cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda) cc = get_compute_capability(cuda)
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
cuda_version_string = get_cuda_version(cuda, cudart_path) cuda_version_string = get_cuda_version(cuda, cudart_path)

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import ctypes as ct import ctypes as ct
import operator
import random import random
import math import math
import torch import torch
@ -11,6 +12,11 @@ from typing import Tuple
from torch import Tensor from torch import Tensor
from .cextension import COMPILED_WITH_CUDA, lib from .cextension import COMPILED_WITH_CUDA, lib
from functools import reduce # Required in Python 3
# math.prod is only available on Python >= 3.8, so provide a portable fallback.
def prod(iterable):
    """Return the product of all elements in *iterable*.

    Stand-in for ``math.prod`` that folds the elements together with
    ``operator.mul``, starting from 1, so an empty iterable yields 1.
    """
    return reduce(operator.mul, iterable, 1)
name2qmap = {} name2qmap = {}
@ -326,8 +332,8 @@ def nvidia_transform(
dim1 = ct.c_int32(shape[0]) dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1]) dim2 = ct.c_int32(shape[1])
elif ld is not None: elif ld is not None:
n = math.prod(shape) n = prod(shape)
dim1 = math.prod([shape[i] for i in ld]) dim1 = prod([shape[i] for i in ld])
dim2 = ct.c_int32(n // dim1) dim2 = ct.c_int32(n // dim1)
dim1 = ct.c_int32(dim1) dim1 = ct.c_int32(dim1)
else: else:
@ -1314,7 +1320,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
m = shapeA[0] * shapeA[1] m = shapeA[0] * shapeA[1]
rows = n = shapeB[0] rows = n = shapeB[0]
assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
# if the tensor is empty, return a transformed empty tensor with the right dimensions # if the tensor is empty, return a transformed empty tensor with the right dimensions
if shapeA[0] == 0 and dimsA == 2: if shapeA[0] == 0 and dimsA == 2:

View File

@ -65,7 +65,7 @@ if [[ -n "$CUDA_VERSION" ]]; then
echo $URL echo $URL
echo $FILE echo $FILE
wget $URL wget $URL
bash $FILE --no-drm --no-man-page --override --installpath=~/local --librarypath=$BASE_PATH/lib --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
source ~/.bashrc source ~/.bashrc

View File

@ -202,4 +202,4 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so" ]; then
fi fi
python -m build python -m build
python -m twine upload dist/* --verbose --repository testpypi python -m twine upload dist/* --verbose

View File

@ -18,7 +18,7 @@ def read(fname):
setup( setup(
name=f"bitsandbytes", name=f"bitsandbytes",
version=f"0.31.4", version=f"0.31.8",
author="Tim Dettmers", author="Tim Dettmers",
author_email="dettmers@cs.washington.edu", author_email="dettmers@cs.washington.edu",
description="8-bit optimizers and matrix multiplication routines.", description="8-bit optimizers and matrix multiplication routines.",