From 336e24696cdadedbeb7612e33eb6512ff3e1f9c9 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Mon, 2 Jan 2023 03:31:43 -0800
Subject: [PATCH] CUDASetup only executed once + fixed circular import.

---
 bitsandbytes/cextension.py          | 119 +------------
 bitsandbytes/cuda_setup/__init__.py |   6 -
 bitsandbytes/cuda_setup/main.py     | 258 ++++++++++++++++++++++++++--
 bitsandbytes/cuda_setup/paths.py    | 119 -------------
 tests/test_cuda_setup_evaluator.py  |   2 +-
 5 files changed, 248 insertions(+), 256 deletions(-)
 delete mode 100644 bitsandbytes/cuda_setup/paths.py

diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index d140f4c..00ee587 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -1,122 +1,17 @@
 import ctypes as ct
+import torch
+
 from pathlib import Path
 from warnings import warn
 
-import torch
+from bitsandbytes.cuda_setup.main import CUDASetup
 
-class CUDASetup:
-    _instance = None
+setup = CUDASetup.get_instance()
+if setup.initialized != True:
+    setup.run_cuda_setup()
 
-    def __init__(self):
-        raise RuntimeError("Call get_instance() instead")
-
-    def generate_instructions(self):
-        if self.cuda is None:
-            self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
-            self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
-            self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:')
-            self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null')
-            self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a')
-            self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc')
-            return
-
-        if self.cudart_path is None:
-            self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.')
-            self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable')
-            self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null')
-            self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a')
-            self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc')
-            self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.')
-            self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://github.com/TimDettmers/bitsandbytes/blob/main/cuda_install.sh')
-            self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.')
-            self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local')
-            return
-
-        make_cmd = f'CUDA_VERSION={self.cuda_version_string}'
-        if len(self.cuda_version_string) < 3:
-            make_cmd += ' make cuda92'
-        elif self.cuda_version_string == '110':
-            make_cmd += ' make cuda110'
-        elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0:
-            make_cmd += ' make cuda11x'
-
-        has_cublaslt = self.cc in ["7.5", "8.0", "8.6"]
-        if not has_cublaslt:
-            make_cmd += '_nomatmul'
-
-        self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:')
-        self.add_log_entry('git clone git@github.com:TimDettmers/bitsandbytes.git')
-        self.add_log_entry('cd bitsandbytes')
-        self.add_log_entry(make_cmd)
-        self.add_log_entry('python setup.py install')
-
-    def initialize(self):
-        self.has_printed = False
-        self.lib = None
-        self.run_cuda_setup()
-
-    def run_cuda_setup(self):
-        self.initialized = True
-        self.cuda_setup_log = []
-
-        from .cuda_setup.main import evaluate_cuda_setup
-        binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
-        self.cudart_path = cudart_path
-        self.cuda = cuda
-        self.cc = cc
-        self.cuda_version_string = cuda_version_string
-
-        package_dir = Path(__file__).parent
-        binary_path = package_dir / binary_name
-
-        try:
-            if not binary_path.exists():
-                self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
-                legacy_binary_name = "libbitsandbytes.so"
-                self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
-                binary_path = package_dir / legacy_binary_name
-                if not binary_path.exists():
-                    self.add_log_entry('')
-                    self.add_log_entry('='*48 + 'ERROR' + '='*37)
-                    self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
-                    self.add_log_entry('1. CUDA driver not installed')
-                    self.add_log_entry('2. CUDA not installed')
-                    self.add_log_entry('3. You have multiple conflicting CUDA libraries')
-                    self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
-                    self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
-                    self.add_log_entry('='*80)
-                    self.add_log_entry('')
-                    self.generate_instructions()
-                    self.print_log_stack()
-                    raise Exception('CUDA SETUP: Setup Failed!')
-                self.lib = ct.cdll.LoadLibrary(binary_path)
-            else:
-                self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
-                self.lib = ct.cdll.LoadLibrary(binary_path)
-        except Exception as ex:
-            self.add_log_entry(str(ex))
-            self.print_log_stack()
-
-    def add_log_entry(self, msg, is_warning=False):
-        self.cuda_setup_log.append((msg, is_warning))
-
-    def print_log_stack(self):
-        for msg, is_warning in self.cuda_setup_log:
-            if is_warning:
-                warn(msg)
-            else:
-                print(msg)
-
-    @classmethod
-    def get_instance(cls):
-        if cls._instance is None:
-            cls._instance = cls.__new__(cls)
-            cls._instance.initialize()
-        return cls._instance
-
-
-lib = CUDASetup.get_instance().lib
+lib = setup.lib
 try:
     if lib is None and torch.cuda.is_available():
         CUDASetup.get_instance().generate_instructions()
diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py
index e781b9d..e69de29 100644
--- a/bitsandbytes/cuda_setup/__init__.py
+++ b/bitsandbytes/cuda_setup/__init__.py
@@ -1,6 +0,0 @@
-from .main import evaluate_cuda_setup
-from .paths import (
-    CUDA_RUNTIME_LIB,
-    determine_cuda_runtime_lib_path,
-    extract_candidate_paths,
-)
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index 61e584b..7dd6ba4 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -16,21 +16,243 @@ evaluation:
     - based on that set the default path
 """
 
-import ctypes
+import ctypes as ct
 import os
-
+import errno
 import torch
 
-from bitsandbytes.cextension import CUDASetup
+from pathlib import Path
+from typing import Set, Union
+from .env_vars import get_potentially_lib_path_containing_env_vars
 
-from .paths import determine_cuda_runtime_lib_path
 
+CUDA_RUNTIME_LIB: str = "libcudart.so"
+
+class CUDASetup:
+    _instance = None
+
+    def __init__(self):
+        raise RuntimeError("Call get_instance() instead")
+
+    def generate_instructions(self):
+        if self.cuda is None:
+            self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
+            self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
+            self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:')
+            self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null')
+            self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a')
+            self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc')
+            return
+
+        if self.cudart_path is None:
+            self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.')
+            self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable')
+            self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null')
+            self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a')
+            self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc')
+            self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.')
+            self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://github.com/TimDettmers/bitsandbytes/blob/main/cuda_install.sh')
+            self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.')
+            self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local')
+            return
+
+        make_cmd = f'CUDA_VERSION={self.cuda_version_string}'
+        if len(self.cuda_version_string) < 3:
+            make_cmd += ' make cuda92'
+        elif self.cuda_version_string == '110':
+            make_cmd += ' make cuda110'
+        elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0:
+            make_cmd += ' make cuda11x'
+
+        has_cublaslt = self.cc in ["7.5", "8.0", "8.6"]
+        if not has_cublaslt:
+            make_cmd += '_nomatmul'
+
+        self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:')
+        self.add_log_entry('git clone git@github.com:TimDettmers/bitsandbytes.git')
+        self.add_log_entry('cd bitsandbytes')
+        self.add_log_entry(make_cmd)
+        self.add_log_entry('python setup.py install')
+
+    def initialize(self):
+        self.has_printed = False
+        self.lib = None
+        self.initialized = False
+
+    def run_cuda_setup(self):
+        self.initialized = True
+        self.cuda_setup_log = []
+
+        binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
+        self.cudart_path = cudart_path
+        self.cuda = cuda
+        self.cc = cc
+        self.cuda_version_string = cuda_version_string
+
+        package_dir = Path(__file__).parent.parent
+        binary_path = package_dir / binary_name
+
+        try:
+            if not binary_path.exists():
+                self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
+                legacy_binary_name = "libbitsandbytes.so"
+                self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
+                binary_path = package_dir / legacy_binary_name
+                if not binary_path.exists():
+                    self.add_log_entry('')
+                    self.add_log_entry('='*48 + 'ERROR' + '='*37)
+                    self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
+                    self.add_log_entry('1. CUDA driver not installed')
+                    self.add_log_entry('2. CUDA not installed')
+                    self.add_log_entry('3. You have multiple conflicting CUDA libraries')
+                    self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
+                    self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
+                    self.add_log_entry('='*80)
+                    self.add_log_entry('')
+                    self.generate_instructions()
+                    self.print_log_stack()
+                    raise Exception('CUDA SETUP: Setup Failed!')
+                self.lib = ct.cdll.LoadLibrary(binary_path)
+            else:
+                self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
+                self.lib = ct.cdll.LoadLibrary(binary_path)
+        except Exception as ex:
+            self.add_log_entry(str(ex))
+            self.print_log_stack()
+
+    def add_log_entry(self, msg, is_warning=False):
+        self.cuda_setup_log.append((msg, is_warning))
+
+    def print_log_stack(self):
+        for msg, is_warning in self.cuda_setup_log:
+            if is_warning:
+                warn(msg)
+            else:
+                print(msg)
+
+    @classmethod
+    def get_instance(cls):
+        if cls._instance is None:
+            cls._instance = cls.__new__(cls)
+            cls._instance.initialize()
+        return cls._instance
+
+
+def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
+    return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
+
+
+def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
+    existent_directories: Set[Path] = set()
+    for path in candidate_paths:
+        try:
+            if path.exists():
+                existent_directories.add(path)
+        except OSError as exc:
+            if exc.errno != errno.ENAMETOOLONG:
+                raise exc
+
+    non_existent_directories: Set[Path] = candidate_paths - existent_directories
+    if non_existent_directories:
+        CUDASetup.get_instance().add_log_entry("WARNING: The following directories listed in your path were found to "
+            f"be non-existent: {non_existent_directories}", is_warning=True)
+
+    return existent_directories
+
+
+def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
+    return {
+        path / CUDA_RUNTIME_LIB
+        for path in candidate_paths
+        if (path / CUDA_RUNTIME_LIB).is_file()
+    }
+
+
+def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
+    """
+    Searches a given environmental var for the CUDA runtime library,
+    i.e. `libcudart.so`.
+    """
+    return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate))
+
+
+def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
+    return get_cuda_runtime_lib_paths(
+        resolve_paths_list(paths_list_candidate)
+    )
+
+
+def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
+    if len(results_paths) > 1:
+        warning_msg = (
+            f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
" + "We'll flip a coin and try one of these, in order to fail forward.\n" + "Either way, this might cause trouble in the future:\n" + "If you get `CUDA error: invalid device function` errors, the above " + "might be the cause and the solution is to make sure only one " + f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.") + CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) + + +def determine_cuda_runtime_lib_path() -> Union[Path, None]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. + """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + if "CONDA_PREFIX" in candidate_env_vars: + conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib" + + conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path)) + warn_in_case_of_duplicates(conda_cuda_libs) + + if conda_cuda_libs: + return next(iter(conda_cuda_libs)) + + CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' + f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True) + + if "LD_LIBRARY_PATH" in candidate_env_vars: + lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) + + if lib_ld_cuda_libs: + return next(iter(lib_ld_cuda_libs)) + warn_in_case_of_duplicates(lib_ld_cuda_libs) + + CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' + f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True) + + remaining_candidate_env_vars = { + env_var: value for env_var, value in candidate_env_vars.items() + if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} + } + + cuda_runtime_libs = set() + for env_var, value in remaining_candidate_env_vars.items(): + cuda_runtime_libs.update(find_cuda_lib_in(value)) + + if len(cuda_runtime_libs) == 0: + CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...') + cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) + + warn_in_case_of_duplicates(cuda_runtime_libs) + + return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None def check_cuda_result(cuda, result_val): # 3. Check for CUDA errors if result_val != 0: - error_str = ctypes.c_char_p() - cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) + error_str = ct.c_char_p() + cuda.cuGetErrorString(result_val, ct.byref(error_str)) CUDASetup.get_instance().add_log_entry(f"CUDA exception! 
@@ -39,13 +261,13 @@ def get_cuda_version(cuda, cudart_path):
     if cuda is None: return None
 
     try:
-        cudart = ctypes.CDLL(cudart_path)
+        cudart = ct.CDLL(cudart_path)
     except OSError:
         CUDASetup.get_instance().add_log_entry(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
         return None
 
-    version = ctypes.c_int()
-    check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
+    version = ct.c_int()
+    check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ct.byref(version)))
     version = int(version.value)
     major = version//1000
     minor = (version-(major*1000))//10
@@ -59,7 +281,7 @@ def get_cuda_version(cuda, cudart_path):
 def get_cuda_lib_handle():
     # 1. find libcuda.so library (GPU driver) (/usr/lib)
     try:
-        cuda = ctypes.CDLL("libcuda.so")
+        cuda = ct.CDLL("libcuda.so")
     except OSError:
         CUDASetup.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
         return None
@@ -79,18 +301,18 @@ def get_compute_capabilities(cuda):
     # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
     """
-    nGpus = ctypes.c_int()
-    cc_major = ctypes.c_int()
-    cc_minor = ctypes.c_int()
+    nGpus = ct.c_int()
+    cc_major = ct.c_int()
+    cc_minor = ct.c_int()
 
-    device = ctypes.c_int()
+    device = ct.c_int()
 
-    check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
+    check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus)))
     ccs = []
     for i in range(nGpus.value):
-        check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
-        ref_major = ctypes.byref(cc_major)
-        ref_minor = ctypes.byref(cc_minor)
+        check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i))
+        ref_major = ct.byref(cc_major)
+        ref_minor = ct.byref(cc_minor)
         # 2. call extern C function to determine CC
         check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
         ccs.append(f"{cc_major.value}.{cc_minor.value}")
diff --git a/bitsandbytes/cuda_setup/paths.py b/bitsandbytes/cuda_setup/paths.py
deleted file mode 100644
index 1c100db..0000000
--- a/bitsandbytes/cuda_setup/paths.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import errno
-from pathlib import Path
-from typing import Set, Union
-
-from bitsandbytes.cextension import CUDASetup
-
-from .env_vars import get_potentially_lib_path_containing_env_vars
-
-CUDA_RUNTIME_LIB: str = "libcudart.so"
-
-
-def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
-    return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
-
-
-def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
-    existent_directories: Set[Path] = set()
-    for path in candidate_paths:
-        try:
-            if path.exists():
-                existent_directories.add(path)
-        except OSError as exc:
-            if exc.errno != errno.ENAMETOOLONG:
-                raise exc
-
-    non_existent_directories: Set[Path] = candidate_paths - existent_directories
-    if non_existent_directories:
-        CUDASetup.get_instance().add_log_entry("WARNING: The following directories listed in your path were found to "
-            f"be non-existent: {non_existent_directories}", is_warning=True)
-
-    return existent_directories
-
-
-def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
-    return {
-        path / CUDA_RUNTIME_LIB
-        for path in candidate_paths
-        if (path / CUDA_RUNTIME_LIB).is_file()
-    }
-
-
-def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
-    """
-    Searches a given environmental var for the CUDA runtime library,
-    i.e. `libcudart.so`.
-    """
-    return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate))
-
-
-def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
-    return get_cuda_runtime_lib_paths(
-        resolve_paths_list(paths_list_candidate)
-    )
-
-
-def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
-    if len(results_paths) > 1:
-        warning_msg = (
-            f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
-            "We'll flip a coin and try one of these, in order to fail forward.\n"
-            "Either way, this might cause trouble in the future:\n"
-            "If you get `CUDA error: invalid device function` errors, the above "
-            "might be the cause and the solution is to make sure only one "
-            f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.")
-        CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
-
-
-def determine_cuda_runtime_lib_path() -> Union[Path, None]:
-    """
-    Searches for a cuda installations, in the following order of priority:
-        1. active conda env
-        2. LD_LIBRARY_PATH
-        3. any other env vars, while ignoring those that
-            - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`)
-            - don't contain the path separator `/`
-
-    If multiple libraries are found in part 3, we optimistically try one,
-    while giving a warning message.
-    """
-    candidate_env_vars = get_potentially_lib_path_containing_env_vars()
-
-    if "CONDA_PREFIX" in candidate_env_vars:
-        conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"
-
-        conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path))
-        warn_in_case_of_duplicates(conda_cuda_libs)
-
-        if conda_cuda_libs:
-            return next(iter(conda_cuda_libs))
-
-        CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
-            f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
-
-    if "LD_LIBRARY_PATH" in candidate_env_vars:
-        lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
-
-        if lib_ld_cuda_libs:
-            return next(iter(lib_ld_cuda_libs))
-        warn_in_case_of_duplicates(lib_ld_cuda_libs)
-
-        CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
-            f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
-
-    remaining_candidate_env_vars = {
-        env_var: value for env_var, value in candidate_env_vars.items()
-        if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
-    }
-
-    cuda_runtime_libs = set()
-    for env_var, value in remaining_candidate_env_vars.items():
-        cuda_runtime_libs.update(find_cuda_lib_in(value))
-
-    if len(cuda_runtime_libs) == 0:
-        CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
-        cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
-
-    warn_in_case_of_duplicates(cuda_runtime_libs)
-
-    return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index 7edc01f..c0da1d3 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -4,7 +4,7 @@ from typing import List, NamedTuple
 import pytest
 
 import bitsandbytes as bnb
-from bitsandbytes.cuda_setup import (
+from bitsandbytes.cuda_setup.main import (
     CUDA_RUNTIME_LIB,
     determine_cuda_runtime_lib_path,
     evaluate_cuda_setup,