Changed CUDA setup to use PyTorch default; added a weak test.

This commit is contained in:
Tim Dettmers 2023-07-13 23:58:41 -07:00
parent ac155f7415
commit 1ab6758b36
3 changed files with 65 additions and 106 deletions

View File

@ -38,5 +38,5 @@ except AttributeError as ex:
# print the setup details after checking for errors so we do not print twice
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
setup.print_log_stack()
#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
#setup.print_log_stack()

View File

@ -47,13 +47,14 @@ class CUDASetup:
if getattr(self, 'error', False): return
print(self.error)
self.error = True
if self.cuda is None:
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
if not self.cuda_available:
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected or CUDA not installed.')
self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:')
self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null')
self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a')
self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc')
self.add_log_entry('CUDA SETUP: Solution 3): For a missing CUDA runtime library (libcudart.so), use `find / -name libcudart.so* and follow with step (2b)')
return
if self.cudart_path is None:
@ -98,20 +99,26 @@ class CUDASetup:
self.initialized = False
self.error = False
def manual_override(self):
if torch.cuda.is_available():
if 'CUDA_HOME' in os.environ and 'CUDA_VERSION' in os.environ:
if len(os.environ['CUDA_HOME']) > 0 and len(os.environ['CUDA_VERSION']) > 0:
self.binary_name = self.binary_name[:-6] + f'{os.environ["CUDA_VERSION"]}.so'
def run_cuda_setup(self):
self.initialized = True
self.cuda_setup_log = []
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
binary_name, cudart_path, cc, cuda_version_string = evaluate_cuda_setup()
self.cudart_path = cudart_path
self.cuda = cuda
self.cuda_available = torch.cuda.is_available()
self.cc = cc
self.cuda_version_string = cuda_version_string
self.binary_name = binary_name
self.manual_override()
package_dir = Path(__file__).parent.parent
binary_path = package_dir / binary_name
print('bin', binary_path)
binary_path = package_dir / self.binary_name
try:
if not binary_path.exists():
@ -123,10 +130,12 @@ class CUDASetup:
self.add_log_entry('')
self.add_log_entry('='*48 + 'ERROR' + '='*37)
self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
self.add_log_entry('1. CUDA driver not installed')
self.add_log_entry('2. CUDA not installed')
self.add_log_entry('3. You have multiple conflicting CUDA libraries')
self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
self.add_log_entry('1. You need to manually override the PyTorch CUDA version. Please see: '
'"https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md')
self.add_log_entry('2. CUDA driver not installed')
self.add_log_entry('3. CUDA not installed')
self.add_log_entry('4. You have multiple conflicting CUDA libraries')
self.add_log_entry('5. Required library not pre-compiled for this bitsandbytes release!')
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.')
self.add_log_entry('='*80)
@ -218,11 +227,13 @@ def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
if len(results_paths) > 1:
warning_msg = (
f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. "
"We'll flip a coin and try one of these, in order to fail forward.\n"
"Either way, this might cause trouble in the future:\n"
"If you get `CUDA error: invalid device function` errors, the above "
"might be the cause and the solution is to make sure only one "
f"{CUDA_RUNTIME_LIBS} in the paths that we search based on your env.")
"We select the PyTorch default libcudart.so, which is {torch.version.cuda},"
"but this might missmatch with the CUDA version that is needed for bitsandbytes."
"To override this behavior set the CUDA_HOME environmental variable"
"For example, if you want to use the CUDA version wht the path"
"/usr/local/cuda-11.2/lib/libcudart.so as the default,"
"then add the following to your .bashrc:"
"export CUDA_HOME=/usr/local/cuda-11.2")
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
@ -240,6 +251,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
"""
candidate_env_vars = get_potentially_lib_path_containing_env_vars()
cuda_runtime_libs = set()
if "CONDA_PREFIX" in candidate_env_vars:
conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"
@ -247,7 +259,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
warn_in_case_of_duplicates(conda_cuda_libs)
if conda_cuda_libs:
return next(iter(conda_cuda_libs))
cuda_runtime_libs.update(conda_cuda_libs)
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
@ -256,7 +268,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
if lib_ld_cuda_libs:
return next(iter(lib_ld_cuda_libs))
cuda_runtime_libs.update(lib_ld_cuda_libs)
warn_in_case_of_duplicates(lib_ld_cuda_libs)
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
@ -277,13 +289,13 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
warn_in_case_of_duplicates(cuda_runtime_libs)
print(cuda_runtime_libs, flush=True)
return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
def get_cuda_version(cuda, cudart_path):
if cuda is None: return None
def get_cuda_version():
major, minor = map(int, torch.version.cuda.split("."))
if major < 11:
@ -291,19 +303,7 @@ def get_cuda_version(cuda, cudart_path):
return f'{major}{minor}'
def get_cuda_lib_handle():
# 1. find libcuda.so library (GPU driver) (/usr/lib)
try:
cuda = ct.CDLL("libcuda.so")
except OSError:
CUDASetup.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
return None
return cuda
def get_compute_capabilities(cuda):
def get_compute_capabilities():
ccs = []
for i in range(torch.cuda.device_count()):
cc_major, cc_minor = torch.cuda.get_device_capability(torch.cuda.device(i))
@ -312,20 +312,6 @@ def get_compute_capabilities(cuda):
return ccs
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
def get_compute_capability(cuda):
"""
Extracts the highest compute capbility from all available GPUs, as compute
capabilities are downwards compatible. If no GPUs are detected, it returns
None.
"""
if cuda is None: return None
# TODO: handle different compute capabilities; for now, take the max
ccs = get_compute_capabilities(cuda)
if ccs: return ccs[-1]
def evaluate_cuda_setup():
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
print('')
@ -337,27 +323,15 @@ def evaluate_cuda_setup():
cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path()
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
cuda_version_string = get_cuda_version(cuda, cudart_path)
ccs = get_compute_capabilities()
ccs.sort()
cc = ccs[-1] # we take the highest capability
cuda_version_string = get_cuda_version()
failure = False
if cudart_path is None:
failure = True
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
else:
cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.")
cuda_setup.add_log_entry(f"CUDA SETUP: To manually override the PyTorch CUDA version please see:"
"https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md")
if cc == '' or cc is None:
failure = True
cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True)
else:
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
if cuda is None:
failure = True
else:
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
# 7.5 is the minimum CC vor cublaslt
has_cublaslt = is_cublasLt_compatible(cc)
@ -369,12 +343,10 @@ def evaluate_cuda_setup():
# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
if failure:
binary_name = "libbitsandbytes_cpu.so"
elif has_cublaslt:
if has_cublaslt:
binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so"
else:
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"
return binary_name, cudart_path, cuda, cc, cuda_version_string
return binary_name, cudart_path, cc, cuda_version_string

View File

@ -1,40 +1,27 @@
import os
from typing import List, NamedTuple
import pytest
import torch
from pathlib import Path
# hardcoded test. Not good, but a sanity check for now
def test_manual_override():
manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2'))
pytorch_version = torch.version.cuda.replace('.', '')
assert pytorch_version != 122
os.environ['CUDA_HOME']='{manual_cuda_path}'
os.environ['CUDA_VERSION']='122'
assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH']
import bitsandbytes as bnb
loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name
assert loaded_lib == 'libbitsandbytes_cuda122.so'
import bitsandbytes as bnb
from bitsandbytes.cuda_setup.main import (
determine_cuda_runtime_lib_path,
evaluate_cuda_setup,
extract_candidate_paths,
)
def test_cuda_full_system():
## this only tests the cuda version and not compute capability
# if CONDA_PREFIX exists, it has priority before all other env variables
# but it does not contain the library directly, so we need to look at the a sub-folder
version = ""
if "CONDA_PREFIX" in os.environ:
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so.11.0')
major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split("."))
version = float(f"{major}.{minor}")
if version == "" and "LD_LIBRARY_PATH" in os.environ:
ld_path = os.environ["LD_LIBRARY_PATH"]
paths = ld_path.split(":")
version = ""
for p in paths:
if "cuda" in p:
idx = p.rfind("cuda-")
version = p[idx + 5 : idx + 5 + 4].replace("/", "")
version = float(version)
break
assert version > 0
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
binary_name = binary_name.replace("libbitsandbytes_cuda", "")
assert binary_name.startswith(str(version).replace(".", ""))