diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 2374c35..4bc7bf7 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,9 +1,34 @@ import ctypes as ct import os from warnings import warn +from bitsandbytes.cuda_setup import evaluate_cuda_setup -lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so') +class CUDALibrary_Singleton(object): + _instance = None + + def __init__(self): + raise RuntimeError('Call get_instance() instead') + + def initialize(self): + self.context = {} + binary_name = evaluate_cuda_setup() + if not os.path.exists(os.path.dirname(__file__) + f'/{binary_name}'): + print(f'TODO: compile library for specific version: {binary_name}') + print('defaulting to libbitsandbytes.so') + self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so') + else: + self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + f'/{binary_name}') + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + +lib = CUDALibrary_Singleton.get_instance().lib try: lib.cadam32bit_g32 lib.get_context.restype = ct.c_void_p diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py index 48423b5..6f67275 100644 --- a/bitsandbytes/cuda_setup.py +++ b/bitsandbytes/cuda_setup.py @@ -23,6 +23,58 @@ from pathlib import Path from typing import Set, Union from .utils import warn_of_missing_prerequisite, print_err +import ctypes +import shlex +import subprocess + +def execute_and_return(strCMD): + proc = subprocess.Popen(shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = proc.communicate() + out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip() + return out, err + +def check_cuda_result(cuda, result_val): + if result_val != 0: + cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) + print(f"Count not initialize CUDA - failure!") + raise Exception('CUDA excepion!') + return result_val + +# taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 +def get_compute_capability(): + libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll') + for libname in libnames: + try: + cuda = ctypes.CDLL(libname) + except OSError: + continue + else: + break + else: + raise OSError("could not load any of: " + ' '.join(libnames)) + + + nGpus = ctypes.c_int() + cc_major = ctypes.c_int() + cc_minor = ctypes.c_int() + + result = ctypes.c_int() + device = ctypes.c_int() + context = ctypes.c_void_p() + error_str = ctypes.c_char_p() + + result = check_cuda_result(cuda, cuda.cuInit(0)) + + result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) + ccs = [] + for i in range(nGpus.value): + result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) + result = check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device)) + ccs.append(f'{cc_major.value}.{cc_minor.value}') + + #TODO: handle different compute capabilities; for now, take the max + ccs.sort() + return ccs[-1] CUDA_RUNTIME_LIB: str = "libcudart.so" @@ -72,12 +124,30 @@ def get_cuda_runtime_lib_path( raise FileNotFoundError(err_msg) single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs)) - return ld_library_paths + return single_cuda_runtime_lib_dir def evaluate_cuda_setup(): - # - if paths faulty, return meaningful error - # - else: - # - determine CUDA version - # - determine capabilities - # - based on that set the default path - pass + cuda_path = get_cuda_runtime_lib_path() + cc = get_compute_capability() + binary_name = 'libbitsandbytes_cpu.so' + + has_gpu = cc != '' + if not has_gpu: + print('WARNING: No GPU detected! Check our CUDA paths. Processding to load CPU-only library...') + return binary_name + + has_cublaslt = cc in ['7.5', '8.0', '8.6'] + + # TODO: + # (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + cuda_home = str(Path(cuda_path).parent.parent) + ls_output, err = execute_and_return(f'{cuda_home}/bin/nvcc --version') + cuda_version = ls_output.split('\n')[3].split(',')[-1].strip().lower().replace('v', '') + major, minor, revision = cuda_version.split('.') + cuda_version_string = f'{major}{minor}' + + binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so' + + return binary_name diff --git a/install_cuda.sh b/install_cuda.sh deleted file mode 100644 index 6a4ff0c..0000000 --- a/install_cuda.sh +++ /dev/null @@ -1,5 +0,0 @@ -wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run -bash cuda_11.1.1_455.32.00_linux.run --no-drm --no-man-page --override --installpath=~/local --librarypath=~/local/lib --toolkitpath=~/local/cuda-11.1/ --toolkit --silent -echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/local/cuda-11.1/lib64/" >> ~/.bashrc -echo "export PATH=$PATH:~/local/cuda-11.1/bin/" >> ~/.bashrc -source ~/.bashrc diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 96ee6c5..72aa3c7 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,4 +1,5 @@ import pytest +import os from typing import List @@ -16,6 +17,7 @@ HAPPY_PATH__LD_LIB_TEST_PATHS: List[tuple[str,str]] = [ (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"), + (f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", f"dir/with/{CUDA_RUNTIME_LIB}"), ] @@ -64,3 +66,21 @@ def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path): match in std_err for match in {"WARNING", "non-existent"} ) + +def test_full_system(): + ## this only tests the cuda version and not compute capability + ld_path = os.environ['LD_LIBRARY_PATH'] + paths = ld_path.split(':') + version = '' + for p in paths: + if 'cuda' in p: + idx = p.rfind('cuda-') + version = p[idx+5:idx+5+4].replace('/', '') + version = float(version) + break + + binary_name = evaluate_cuda_setup() + binary_name = binary_name.replace('libbitsandbytes_cuda', '') + assert binary_name.startswith(str(version).replace('.', '')) + +