From 3809236428e704f9a7e22232701a651aafa5ca1b Mon Sep 17 00:00:00 2001
From: Titus von Koeller <titus@vonkoeller.com>
Date: Tue, 2 Aug 2022 07:42:27 -0700
Subject: [PATCH] move cuda_setup code into subpackage

---
 bitsandbytes/__init__.py                      |  1 +
 bitsandbytes/cextension.py                    |  2 +-
 bitsandbytes/cuda_setup/__init__.py           |  0
 bitsandbytes/cuda_setup/compute_capability.py | 65 +++++++++++++++++++
 .../{cuda_setup.py => cuda_setup/main.py}     |  4 +-
 5 files changed, 69 insertions(+), 3 deletions(-)
 create mode 100644 bitsandbytes/cuda_setup/__init__.py
 create mode 100644 bitsandbytes/cuda_setup/compute_capability.py
 rename bitsandbytes/{cuda_setup.py => cuda_setup/main.py} (98%)

diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 6e5b6ac..76a5b48 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -12,6 +12,7 @@ from .autograd._functions import (
 )
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
+from . import cuda_setup
 
 if COMPILED_WITH_CUDA:
     from .optim import adam
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index bc11474..f5b97fd 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -2,7 +2,7 @@ import ctypes as ct
 import os
 from warnings import warn
 
-from bitsandbytes.cuda_setup import evaluate_cuda_setup
+from bitsandbytes.cuda_setup.main import evaluate_cuda_setup
 
 
 class CUDALibrary_Singleton(object):
diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bitsandbytes/cuda_setup/compute_capability.py b/bitsandbytes/cuda_setup/compute_capability.py
new file mode 100644
index 0000000..19ceb3b
--- /dev/null
+++ b/bitsandbytes/cuda_setup/compute_capability.py
@@ -0,0 +1,65 @@
+import ctypes
+from dataclasses import dataclass, field
+
+
+CUDA_SUCCESS = 0
+
+@dataclass
+class CudaLibVals:
+    """Compute capabilities of all visible GPUs, read via the CUDA driver API.
+
+    Code bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
+    """
+
+    # Un-annotated names are class-level ctypes scratch buffers, not dataclass fields.
+    nGpus = ctypes.c_int()
+    cc_major = ctypes.c_int()
+    cc_minor = ctypes.c_int()
+    device = ctypes.c_int()
+    error_str = ctypes.c_char_p()
+    cuda: ctypes.CDLL = field(init=False, repr=False)
+    ccs: list = field(init=False)  # "major.minor" strings, highest first
+
+    def load_cuda_lib(self):
+        """Load the CUDA driver library (libcuda.so) into `self.cuda`.
+
+        Raises OSError if none of the candidate names can be loaded.
+        """
+        libnames = ("libcuda.so",)  # tuple: a bare string would iterate its chars
+        for libname in libnames:
+            try:
+                self.cuda = ctypes.CDLL(libname)
+            except OSError:
+                continue
+            else:
+                break
+        else:
+            raise OSError("could not load any of: " + " ".join(libnames))
+
+    def check_cuda_result(self, result_val):
+        """Return `result_val` on CUDA_SUCCESS, else raise RuntimeError."""
+        if result_val != CUDA_SUCCESS:
+            self.cuda.cuGetErrorString(result_val, ctypes.byref(self.error_str))
+            detail = (self.error_str.value or b"unknown error").decode()
+            print("Could not initialize CUDA - failure!")
+            raise RuntimeError(f"CUDA exception! Error {result_val}: {detail}")
+        return result_val
+
+    def __post_init__(self):
+        """Load libcuda and fill `ccs` with every GPU's compute capability.
+
+        Uses the driver's cuDeviceComputeCapability, see
+        https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html
+        """
+        self.load_cuda_lib()
+        check = self.check_cuda_result
+        check(self.cuda.cuInit(0))
+        check(self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus)))
+        found = []
+        for gpu_index in range(self.nGpus.value):
+            check(self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index))
+            cc_refs = ctypes.byref(self.cc_major), ctypes.byref(self.cc_minor)
+            check(self.cuda.cuDeviceComputeCapability(*cc_refs, self.device))
+            found.append((self.cc_major.value, self.cc_minor.value))
+        # Sort numerically: as strings, "10.0" would rank below "8.6".
+        self.ccs = [f"{major}.{minor}" for major, minor in sorted(found, reverse=True)]
diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup/main.py
similarity index 98%
rename from bitsandbytes/cuda_setup.py
rename to bitsandbytes/cuda_setup/main.py
index e68cd5e..6d70c92 100644
--- a/bitsandbytes/cuda_setup.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -1,6 +1,6 @@
 """
 extract factors the build is dependent on:
-[X] compute capability  
+[X] compute capability
     [ ] TODO: Q - What if we have multiple GPUs of different makes?
 - CUDA version
 - Software:
@@ -23,7 +23,7 @@ import os
 from pathlib import Path
 from typing import Set, Union
 
-from .utils import print_err, warn_of_missing_prerequisite, execute_and_return
+from ..utils import print_err, warn_of_missing_prerequisite, execute_and_return
 
 
 def check_cuda_result(cuda, result_val):