Added full env variable search; CONDA_PREFIX priority.

This commit is contained in:
Tim Dettmers 2022-08-01 19:22:41 -07:00
parent c4fe6c69a3
commit 8bf3e9faab
3 changed files with 51 additions and 38 deletions

View File

@ -19,11 +19,11 @@ evaluation:
""" """
import ctypes import ctypes
from os import environ as env import os
from pathlib import Path from pathlib import Path
from typing import Set, Union from typing import Set, Union
from .utils import print_err, warn_of_missing_prerequisite from .utils import print_err, warn_of_missing_prerequisite, execute_and_return
def check_cuda_result(cuda, result_val): def check_cuda_result(cuda, result_val):
@ -88,22 +88,11 @@ def tokenize_paths(paths: str) -> Set[Path]:
return {Path(ld_path) for ld_path in paths.split(":") if ld_path} return {Path(ld_path) for ld_path in paths.split(":") if ld_path}
def get_cuda_runtime_lib_path( def resolve_env_variable(env_var):
# TODO: replace this with logic for all paths in env vars paths: Set[Path] = tokenize_paths(env_var)
LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
) -> Union[Path, None]:
"""# TODO: add doc-string"""
if not LD_LIBRARY_PATH:
warn_of_missing_prerequisite(
"LD_LIBRARY_PATH is completely missing from environment!"
)
return None
ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)
non_existent_directories: Set[Path] = { non_existent_directories: Set[Path] = {
path for path in ld_library_paths if not path.exists() path for path in paths if not path.exists()
} }
if non_existent_directories: if non_existent_directories:
@ -114,7 +103,7 @@ def get_cuda_runtime_lib_path(
cuda_runtime_libs: Set[Path] = { cuda_runtime_libs: Set[Path] = {
path / CUDA_RUNTIME_LIB path / CUDA_RUNTIME_LIB
for path in ld_library_paths for path in paths
if (path / CUDA_RUNTIME_LIB).is_file() if (path / CUDA_RUNTIME_LIB).is_file()
} - non_existent_directories } - non_existent_directories
@ -123,19 +112,35 @@ def get_cuda_runtime_lib_path(
f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
) )
raise FileNotFoundError(err_msg) raise FileNotFoundError(err_msg)
elif len(cuda_runtime_libs) == 0: return None
else: return next(iter(cuda_runtime_libs)) # for now just return the first
elif len(cuda_runtime_libs) < 1: def get_cuda_runtime_lib_path() -> Union[Path, None]:
"""# TODO: add doc-string"""
cuda_runtime_libs = []
if 'CONDA_PREFIX' in os.environ:
lib_conda_path = f'{os.environ["CONDA_PREFIX"]}/lib/'
print(lib_conda_path)
cuda_runtime_libs.append(resolve_env_variable(lib_conda_path))
if len(cuda_runtime_libs) == 1: return cuda_runtime_libs[0]
for var in os.environ:
cuda_runtime_libs.append(resolve_env_variable(var))
if len(cuda_runtime_libs) < 1:
err_msg = ( err_msg = (
f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
) )
raise FileNotFoundError(err_msg) raise FileNotFoundError(err_msg)
single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs)) return cuda_runtime_libs.pop()
return single_cuda_runtime_lib_dir
def evaluate_cuda_setup(): def evaluate_cuda_setup():
cuda_path = get_cuda_runtime_lib_path() cuda_path = get_cuda_runtime_lib_path()
print(f'CUDA SETUP: CUDA path found: {cuda_path}')
cc = get_compute_capability() cc = get_compute_capability()
binary_name = "libbitsandbytes_cpu.so" binary_name = "libbitsandbytes_cpu.so"
@ -152,13 +157,10 @@ def evaluate_cuda_setup():
# (2) Multiple CUDA versions installed # (2) Multiple CUDA versions installed
cuda_home = str(Path(cuda_path).parent.parent) cuda_home = str(Path(cuda_path).parent.parent)
ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version") ls_output, err = execute_and_return(f"ls -l {cuda_path}")
cuda_version = ( major, minor, revision = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.')
ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
)
major, minor, revision = cuda_version.split(".")
cuda_version_string = f"{major}{minor}" cuda_version_string = f"{major}{minor}"
binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so' binary_name = f'libbitsandbytes_cuda{cuda_version_string}{("" if has_cublaslt else "_nocublaslt")}.so'
return binary_name return binary_name

View File

@ -2,6 +2,7 @@ import sys
import shlex import shlex
import subprocess import subprocess
from typing import Tuple
def execute_and_return(command_string: str) -> Tuple[str, str]: def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple): def _decode(subprocess_err_out_tuple):
@ -19,7 +20,7 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
).communicate() ).communicate()
) )
std_out, std_err = execute_and_return_decoded_std_streams() std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err return std_out, std_err

View File

@ -1,7 +1,8 @@
import os import os
from typing import List, NamedTuple
import pytest import pytest
import bitsandbytes as bnb
from typing import List, NamedTuple
from bitsandbytes.cuda_setup import ( from bitsandbytes.cuda_setup import (
CUDA_RUNTIME_LIB, CUDA_RUNTIME_LIB,
@ -91,16 +92,25 @@ def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
def test_full_system(): def test_full_system():
## this only tests the cuda version and not compute capability ## this only tests the cuda version and not compute capability
ld_path = os.environ["LD_LIBRARY_PATH"] version = ''
paths = ld_path.split(":") if 'CONDA_PREFIX' in os.environ:
version = "" ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
for p in paths: major, minor, revision = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.')
if "cuda" in p: version = float(f'{major}.{minor}')
idx = p.rfind("cuda-")
version = p[idx + 5 : idx + 5 + 4].replace("/", "")
version = float(version)
break
if version == '' and 'LD_LIBRARY_PATH':
ld_path = os.environ["LD_LIBRARY_PATH"]
paths = ld_path.split(":")
version = ""
for p in paths:
if "cuda" in p:
idx = p.rfind("cuda-")
version = p[idx + 5 : idx + 5 + 4].replace("/", "")
version = float(version)
break
assert version > 0
binary_name = evaluate_cuda_setup() binary_name = evaluate_cuda_setup()
binary_name = binary_name.replace("libbitsandbytes_cuda", "") binary_name = binary_name.replace("libbitsandbytes_cuda", "")
assert binary_name.startswith(str(version).replace(".", "")) assert binary_name.startswith(str(version).replace(".", ""))