Added full env variable search; CONDA_PREFIX priority.
This commit is contained in:
parent
c4fe6c69a3
commit
8bf3e9faab
|
@ -19,11 +19,11 @@ evaluation:
|
|||
"""
|
||||
|
||||
import ctypes
|
||||
from os import environ as env
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Set, Union
|
||||
|
||||
from .utils import print_err, warn_of_missing_prerequisite
|
||||
from .utils import print_err, warn_of_missing_prerequisite, execute_and_return
|
||||
|
||||
|
||||
def check_cuda_result(cuda, result_val):
|
||||
|
@ -88,22 +88,11 @@ def tokenize_paths(paths: str) -> Set[Path]:
|
|||
return {Path(ld_path) for ld_path in paths.split(":") if ld_path}
|
||||
|
||||
|
||||
def get_cuda_runtime_lib_path(
|
||||
# TODO: replace this with logic for all paths in env vars
|
||||
LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
|
||||
) -> Union[Path, None]:
|
||||
"""# TODO: add doc-string"""
|
||||
|
||||
if not LD_LIBRARY_PATH:
|
||||
warn_of_missing_prerequisite(
|
||||
"LD_LIBRARY_PATH is completely missing from environment!"
|
||||
)
|
||||
return None
|
||||
|
||||
ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)
|
||||
def resolve_env_variable(env_var):
|
||||
paths: Set[Path] = tokenize_paths(env_var)
|
||||
|
||||
non_existent_directories: Set[Path] = {
|
||||
path for path in ld_library_paths if not path.exists()
|
||||
path for path in paths if not path.exists()
|
||||
}
|
||||
|
||||
if non_existent_directories:
|
||||
|
@ -114,7 +103,7 @@ def get_cuda_runtime_lib_path(
|
|||
|
||||
cuda_runtime_libs: Set[Path] = {
|
||||
path / CUDA_RUNTIME_LIB
|
||||
for path in ld_library_paths
|
||||
for path in paths
|
||||
if (path / CUDA_RUNTIME_LIB).is_file()
|
||||
} - non_existent_directories
|
||||
|
||||
|
@ -123,19 +112,35 @@ def get_cuda_runtime_lib_path(
|
|||
f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
|
||||
)
|
||||
raise FileNotFoundError(err_msg)
|
||||
elif len(cuda_runtime_libs) == 0: return None
|
||||
else: return next(iter(cuda_runtime_libs)) # for now just return the first
|
||||
|
||||
elif len(cuda_runtime_libs) < 1:
|
||||
def get_cuda_runtime_lib_path() -> Union[Path, None]:
|
||||
"""# TODO: add doc-string"""
|
||||
|
||||
cuda_runtime_libs = []
|
||||
if 'CONDA_PREFIX' in os.environ:
|
||||
lib_conda_path = f'{os.environ["CONDA_PREFIX"]}/lib/'
|
||||
print(lib_conda_path)
|
||||
cuda_runtime_libs.append(resolve_env_variable(lib_conda_path))
|
||||
|
||||
if len(cuda_runtime_libs) == 1: return cuda_runtime_libs[0]
|
||||
|
||||
for var in os.environ:
|
||||
cuda_runtime_libs.append(resolve_env_variable(var))
|
||||
|
||||
if len(cuda_runtime_libs) < 1:
|
||||
err_msg = (
|
||||
f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
|
||||
)
|
||||
raise FileNotFoundError(err_msg)
|
||||
|
||||
single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
|
||||
return single_cuda_runtime_lib_dir
|
||||
return cuda_runtime_libs.pop()
|
||||
|
||||
|
||||
def evaluate_cuda_setup():
|
||||
cuda_path = get_cuda_runtime_lib_path()
|
||||
print(f'CUDA SETUP: CUDA path found: {cuda_path}')
|
||||
cc = get_compute_capability()
|
||||
binary_name = "libbitsandbytes_cpu.so"
|
||||
|
||||
|
@ -152,13 +157,10 @@ def evaluate_cuda_setup():
|
|||
# (2) Multiple CUDA versions installed
|
||||
|
||||
cuda_home = str(Path(cuda_path).parent.parent)
|
||||
ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version")
|
||||
cuda_version = (
|
||||
ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
|
||||
)
|
||||
major, minor, revision = cuda_version.split(".")
|
||||
ls_output, err = execute_and_return(f"ls -l {cuda_path}")
|
||||
major, minor, revision = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.')
|
||||
cuda_version_string = f"{major}{minor}"
|
||||
|
||||
binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'
|
||||
binary_name = f'libbitsandbytes_cuda{cuda_version_string}{("" if has_cublaslt else "_nocublaslt")}.so'
|
||||
|
||||
return binary_name
|
||||
|
|
|
@ -2,6 +2,7 @@ import sys
|
|||
import shlex
|
||||
import subprocess
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
def execute_and_return(command_string: str) -> Tuple[str, str]:
|
||||
def _decode(subprocess_err_out_tuple):
|
||||
|
@ -19,7 +20,7 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
|
|||
).communicate()
|
||||
)
|
||||
|
||||
std_out, std_err = execute_and_return_decoded_std_streams()
|
||||
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
|
||||
return std_out, std_err
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import os
|
||||
from typing import List, NamedTuple
|
||||
|
||||
import pytest
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from typing import List, NamedTuple
|
||||
|
||||
from bitsandbytes.cuda_setup import (
|
||||
CUDA_RUNTIME_LIB,
|
||||
|
@ -91,16 +92,25 @@ def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
|
|||
|
||||
def test_full_system():
|
||||
## this only tests the cuda version and not compute capability
|
||||
ld_path = os.environ["LD_LIBRARY_PATH"]
|
||||
paths = ld_path.split(":")
|
||||
version = ""
|
||||
for p in paths:
|
||||
if "cuda" in p:
|
||||
idx = p.rfind("cuda-")
|
||||
version = p[idx + 5 : idx + 5 + 4].replace("/", "")
|
||||
version = float(version)
|
||||
break
|
||||
version = ''
|
||||
if 'CONDA_PREFIX' in os.environ:
|
||||
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
|
||||
major, minor, revision = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.')
|
||||
version = float(f'{major}.{minor}')
|
||||
|
||||
|
||||
if version == '' and 'LD_LIBRARY_PATH':
|
||||
ld_path = os.environ["LD_LIBRARY_PATH"]
|
||||
paths = ld_path.split(":")
|
||||
version = ""
|
||||
for p in paths:
|
||||
if "cuda" in p:
|
||||
idx = p.rfind("cuda-")
|
||||
version = p[idx + 5 : idx + 5 + 4].replace("/", "")
|
||||
version = float(version)
|
||||
break
|
||||
|
||||
assert version > 0
|
||||
binary_name = evaluate_cuda_setup()
|
||||
binary_name = binary_name.replace("libbitsandbytes_cuda", "")
|
||||
assert binary_name.startswith(str(version).replace(".", ""))
|
||||
|
|
Loading…
Reference in New Issue
Block a user