Added the case that all env variables are empty (CUDA docker).

This commit is contained in:
Tim Dettmers 2022-08-05 08:57:52 -07:00
parent 6ad8796cfc
commit c472bd56f0
3 changed files with 10 additions and 3 deletions

View File

@ -117,10 +117,16 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
} }
cuda_runtime_libs = set() cuda_runtime_libs = set()
for env_var, value in remaining_candidate_env_vars: for env_var, value in remaining_candidate_env_vars.items():
cuda_runtime_libs.update(find_cuda_lib_in(value)) cuda_runtime_libs.update(find_cuda_lib_in(value))
if len(cuda_runtime_libs) == 0:
print('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
warn_in_case_of_duplicates(cuda_runtime_libs) warn_in_case_of_duplicates(cuda_runtime_libs)
return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None

View File

@ -18,7 +18,7 @@ def read(fname):
setup( setup(
name=f"bitsandbytes", name=f"bitsandbytes",
version=f"0.31.3", version=f"0.31.4",
author="Tim Dettmers", author="Tim Dettmers",
author_email="dettmers@cs.washington.edu", author_email="dettmers@cs.washington.edu",
description="8-bit optimizers and matrix multiplication routines.", description="8-bit optimizers and matrix multiplication routines.",

View File

@ -133,7 +133,7 @@ def test_full_system():
) )
version = float(f"{major}.{minor}") version = float(f"{major}.{minor}")
if version == "" and "LD_LIBRARY_PATH": if version == "" and "LD_LIBRARY_PATH" in os.environ:
ld_path = os.environ["LD_LIBRARY_PATH"] ld_path = os.environ["LD_LIBRARY_PATH"]
paths = ld_path.split(":") paths = ld_path.split(":")
version = "" version = ""
@ -144,6 +144,7 @@ def test_full_system():
version = float(version) version = float(version)
break break
assert version > 0 assert version > 0
binary_name = evaluate_cuda_setup() binary_name = evaluate_cuda_setup()
binary_name = binary_name.replace("libbitsandbytes_cuda", "") binary_name = binary_name.replace("libbitsandbytes_cuda", "")