2022-08-01 00:47:44 +00:00
|
|
|
import os
|
2022-08-01 10:31:48 +00:00
|
|
|
from typing import List, NamedTuple
|
|
|
|
|
|
|
|
import pytest
|
2022-07-28 04:16:04 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
|
|
|
|
get_cuda_runtime_lib_path, tokenize_paths)
|
2022-07-28 04:16:04 +00:00
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
class InputAndExpectedOutput(NamedTuple):
    """A single parametrized test case: a path string and its expected result."""

    input: str  # ':'-separated path string handed to the code under test
    output: str  # value the code under test is expected to produce for `input`
|
2022-07-28 04:16:04 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
# Happy-path LD_LIBRARY_PATH strings, each paired with the single entry that is
# expected to be reported as holding CUDA_RUNTIME_LIB. The cases cover leading,
# trailing and doubled ':' separators, the library entry appearing first, and a
# similarly named libcuda.so entry. Entries are built as InputAndExpectedOutput
# so the values match the List[InputAndExpectedOutput] annotation; being tuples,
# they still unpack into pytest's "test_input, expected" parameters.
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
    InputAndExpectedOutput(
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    InputAndExpectedOutput(
        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    InputAndExpectedOutput(
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    InputAndExpectedOutput(
        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    InputAndExpectedOutput(
        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    InputAndExpectedOutput(
        f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
]
|
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmp_path, request):
    """Materialize one happy-path LD_LIBRARY_PATH case on disk.

    Every directory named in the case's input string is created below
    ``tmp_path``. Each directory whose path mentions ``CUDA_RUNTIME_LIB``
    additionally receives an empty ``CUDA_RUNTIME_LIB`` file.

    Returns:
        The case as an ``InputAndExpectedOutput``. Its strings are returned
        unchanged (still relative), so a consuming test should resolve them
        against its own ``tmp_path``, which pytest shares with this fixture
        within one test.
    """
    # request.param is an (input, output) pair, not a bare path string
    case = InputAndExpectedOutput(*request.param)
    for path in tokenize_paths(case.input):
        test_dir = tmp_path / path
        # parents=True: entries such as "some/other/dir" are nested
        test_dir.mkdir(parents=True, exist_ok=True)
        # str(): membership must be tested on the text, not on a path object
        if CUDA_RUNTIME_LIB in str(path):
            (test_dir / CUDA_RUNTIME_LIB).touch()
    return case
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__happy_path(
    tmp_path, test_input: str, expected: str, monkeypatch
):
    """get_cuda_runtime_lib_path should pick out the expected entry.

    Each directory listed in ``test_input`` is created and given an empty
    ``CUDA_RUNTIME_LIB`` file before the lookup is run on ``test_input``.
    """
    # Work inside the per-test temp dir so the relative directories below are
    # created there instead of polluting the current working directory; the
    # relative test_input/expected strings then resolve against tmp_path.
    # monkeypatch restores the original working directory afterwards.
    monkeypatch.chdir(tmp_path)
    for path in tokenize_paths(test_input):
        # parents=True: entries such as "some/other/dir" are nested
        path.mkdir(parents=True, exist_ok=True)
        (path / CUDA_RUNTIME_LIB).touch()
    # NOTE(review): the library file is placed in *every* listed directory and
    # the result is compared with a plain str; confirm both match the contract
    # of get_cuda_runtime_lib_path (duplicate detection, return type).
    assert get_cuda_runtime_lib_path(test_input) == expected
|
|
|
|
|
|
|
|
|
|
|
|
# LD_LIBRARY_PATH strings in which several entries all name CUDA_RUNTIME_LIB.
# The unhappy-path test expects get_cuda_runtime_lib_path to raise
# FileNotFoundError mentioning "duplicate" for each of them.
UNHAPPY_PATH__LD_LIB_TEST_PATHS: List[str] = [
    f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}",
    f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}",
]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("test_input", UNHAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
    """Several directories each holding CUDA_RUNTIME_LIB must be rejected.

    Each ':'-separated entry of ``test_input`` is created as a directory under
    ``tmp_path`` and given its own ``CUDA_RUNTIME_LIB`` file. The lookup on the
    resulting tmp-rooted path string must raise FileNotFoundError whose
    message mentions "duplicate" and the library name.
    """
    # Split first: joining the whole ':'-separated string onto tmp_path would
    # yield one bogus path containing colons whose parent never exists.
    lib_dirs = [tmp_path / entry for entry in test_input.split(":") if entry]
    for lib_dir in lib_dirs:
        lib_dir.mkdir(parents=True, exist_ok=True)
        (lib_dir / CUDA_RUNTIME_LIB).touch()
    # Hand the function a path *string*, as in the other tests, not a Path
    ld_library_path = ":".join(str(lib_dir) for lib_dir in lib_dirs)

    with pytest.raises(FileNotFoundError) as err_info:
        get_cuda_runtime_lib_path(ld_library_path)

    # ExceptionInfo supports no `in` test; inspect the exception's message
    message = str(err_info.value)
    assert all(match in message for match in ("duplicate", CUDA_RUNTIME_LIB))
|
2022-07-28 04:16:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
    """A path entry that does not exist should be warned about on stderr."""
    existent_dir = tmp_path / "a/b"
    # parents=True: the intermediate directory "a" does not exist yet
    existent_dir.mkdir(parents=True)
    non_existent_dir = tmp_path / "c/d"  # deliberately never created
    test_input = ":".join([str(existent_dir), str(non_existent_dir)])

    # NOTE(review): no CUDA_RUNTIME_LIB file is created here, so depending on
    # get_cuda_runtime_lib_path's contract this call may raise before the
    # warning is checked; confirm.
    get_cuda_runtime_lib_path(test_input)

    std_err = capsys.readouterr().err
    assert all(match in std_err for match in ("WARNING", "non-existent"))
|
|
|
|
|
2022-08-01 00:47:44 +00:00
|
|
|
|
|
|
|
def test_full_system():
    """Check evaluate_cuda_setup's binary name against LD_LIBRARY_PATH.

    The CUDA version is taken from the first LD_LIBRARY_PATH entry containing
    ``cuda-<major>.<minor>``. It must prefix the binary name once
    ``libbitsandbytes_cuda`` is stripped (for example, ``cuda-11.7`` gives
    ``117``). The test is skipped when no such entry exists, rather than
    passing vacuously or crashing.
    """
    ## this only tests the cuda version and not compute capability
    import re

    # .get(): an unset LD_LIBRARY_PATH leads to a skip below, not a KeyError
    ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    version = None
    for p in ld_path.split(":"):
        # Require "cuda-X.Y"; a bare ".../cuda/..." entry carries no version
        # and previously produced a garbage slice and a float() ValueError.
        match = re.search(r"cuda-(\d+)\.(\d+)", p)
        if match:
            # "11.7" -> "117", the same digits as str(float("11.7")) sans "."
            version = match.group(1) + match.group(2)
            break
    if version is None:
        # An empty version made startswith("") pass unconditionally
        pytest.skip("no cuda-<major>.<minor> entry found in LD_LIBRARY_PATH")

    binary_name = evaluate_cuda_setup()
    binary_name = binary_name.replace("libbitsandbytes_cuda", "")
    assert binary_name.startswith(version)
|