diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 66c79d8..e0f280a 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -17,12 +17,13 @@ class CUDALibrary_Singleton(object): binary_path = package_dir / binary_name if not binary_path.exists(): - print(f"TODO: compile library for specific version: {binary_name}") + print(f"CUDA_SETUP: TODO: compile library for specific version: {binary_name}") legacy_binary_name = "libbitsandbytes.so" - print(f"Defaulting to {legacy_binary_name}...") + print(f"CUDA_SETUP: Defaulting to {legacy_binary_name}...") self.lib = ct.cdll.LoadLibrary(package_dir / legacy_binary_name) else: - self.lib = ct.cdll.LoadLibrary(package_dir / binary_name) + print(f"CUDA_SETUP: Loading binary {binary_path}...") + self.lib = ct.cdll.LoadLibrary(binary_path) @classmethod def get_instance(cls): diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py index e69de29..d8ebba8 100644 --- a/bitsandbytes/cuda_setup/__init__.py +++ b/bitsandbytes/cuda_setup/__init__.py @@ -0,0 +1,2 @@ +from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path +from .main import evaluate_cuda_setup diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index e96ac70..1e52f89 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -47,6 +47,7 @@ def get_compute_capabilities(): cuda = ctypes.CDLL("libcuda.so") except OSError: # TODO: shouldn't we error or at least warn here? + print('ERROR: libcuda.so not found!') return None nGpus = ctypes.c_int() @@ -70,7 +71,7 @@ def get_compute_capabilities(): ) ccs.append(f"{cc_major.value}.{cc_minor.value}") - return ccs.sort() + return ccs # def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error @@ -80,7 +81,8 @@ def get_compute_capability(): capabilities are downwards compatible. If no GPUs are detected, it returns None. """ - if ccs := get_compute_capabilities() is not None: + ccs = get_compute_capabilities() + if ccs is not None: # TODO: handle different compute capabilities; for now, take the max return ccs[-1] return None @@ -92,8 +94,7 @@ def evaluate_cuda_setup(): cc = get_compute_capability() binary_name = "libbitsandbytes_cpu.so" - # FIXME: has_gpu is still unused - if not (has_gpu := bool(cc)): + if cc == '': print( "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." ) @@ -115,6 +116,7 @@ def evaluate_cuda_setup(): ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".") ) cuda_version_string = f"{major}{minor}" + print(f'CUDA_SETUP: Detected CUDA version {cuda_version_string}') def get_binary_name(): "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" @@ -122,6 +124,8 @@ def evaluate_cuda_setup(): if has_cublaslt: return f"{bin_base_name}{cuda_version_string}.so" else: - return f"{bin_base_name}_nocublaslt.so" + return f"{bin_base_name}{cuda_version_string}_nocublaslt.so" + + binary_name = get_binary_name() return binary_name diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 8ebe8c8..f1a15f5 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -351,9 +351,9 @@ def test_matmullt( err = torch.abs(out_bnb - out_torch).mean().item() # print(f'abs error {err:.4f}') idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx == 0).sum().item() < n * 0.0175 + assert (idx == 0).sum().item() <= n * 0.0175 idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) - assert (idx == 0).sum().item() < n * 0.001 + assert (idx == 0).sum().item() <= n * 0.001 if has_fp16_weights: if any(req_grad): @@ -391,9 +391,9 @@ def test_matmullt( assert torch.abs(gradB2).sum() == 0.0 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx == 0).sum().item() < n * 0.1 + assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx == 0).sum().item() < n * 0.02 + assert (idx == 0).sum().item() <= n * 0.02 torch.testing.assert_allclose( gradB1, gradB2, atol=0.18, rtol=0.3 )