move cuda_setup code into subpackage
This commit is contained in:
parent
e120c4a550
commit
3809236428
|
@ -12,6 +12,7 @@ from .autograd._functions import (
|
||||||
)
|
)
|
||||||
from .cextension import COMPILED_WITH_CUDA
|
from .cextension import COMPILED_WITH_CUDA
|
||||||
from .nn import modules
|
from .nn import modules
|
||||||
|
from . import cuda_setup
|
||||||
|
|
||||||
if COMPILED_WITH_CUDA:
|
if COMPILED_WITH_CUDA:
|
||||||
from .optim import adam
|
from .optim import adam
|
||||||
|
|
|
@ -2,7 +2,7 @@ import ctypes as ct
|
||||||
import os
|
import os
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
from bitsandbytes.cuda_setup import evaluate_cuda_setup
|
from bitsandbytes.cuda_setup.main import evaluate_cuda_setup
|
||||||
|
|
||||||
|
|
||||||
class CUDALibrary_Singleton(object):
|
class CUDALibrary_Singleton(object):
|
||||||
|
|
0
bitsandbytes/cuda_setup/__init__.py
Normal file
0
bitsandbytes/cuda_setup/__init__.py
Normal file
65
bitsandbytes/cuda_setup/compute_capability.py
Normal file
65
bitsandbytes/cuda_setup/compute_capability.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
import ctypes
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
CUDA_SUCCESS = 0
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CudaLibVals:
|
||||||
|
# code bits taken from
|
||||||
|
# https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
||||||
|
|
||||||
|
nGpus = ctypes.c_int()
|
||||||
|
cc_major = ctypes.c_int()
|
||||||
|
cc_minor = ctypes.c_int()
|
||||||
|
device = ctypes.c_int()
|
||||||
|
error_str = ctypes.c_char_p()
|
||||||
|
cuda: ctypes.CDLL = field(init=False, repr=False)
|
||||||
|
ccs: List[str, ...] = field(init=False)
|
||||||
|
|
||||||
|
def load_cuda_lib(self):
|
||||||
|
"""
|
||||||
|
1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||||
|
init_device -> init variables -> call function by reference
|
||||||
|
"""
|
||||||
|
libnames = ("libcuda.so")
|
||||||
|
for libname in libnames:
|
||||||
|
try:
|
||||||
|
self.cuda = ctypes.CDLL(libname)
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise OSError("could not load any of: " + " ".join(libnames))
|
||||||
|
|
||||||
|
def check_cuda_result(self, result_val):
|
||||||
|
"""
|
||||||
|
2. call extern C function to determine CC
|
||||||
|
(see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
||||||
|
"""
|
||||||
|
cls_fields: Tuple[Field, ...] = fields(self.__class__)
|
||||||
|
|
||||||
|
if result_val != 0:
|
||||||
|
self.cuda.cuGetErrorString(result_val, ctypes.byref(self.error_str))
|
||||||
|
print("Count not initialize CUDA - failure!")
|
||||||
|
raise Exception("CUDA exception!")
|
||||||
|
return result_val
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.load_cuda_lib()
|
||||||
|
self.check_cuda_result(self.cuda.cuInit(0))
|
||||||
|
self.check_cuda_result(self.cuda, self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus)))
|
||||||
|
tmp_ccs = []
|
||||||
|
for gpu_index in range(self.nGpus.value):
|
||||||
|
check_cuda_result(
|
||||||
|
self.cuda, self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index)
|
||||||
|
)
|
||||||
|
check_cuda_result(
|
||||||
|
self.cuda,
|
||||||
|
self.cuda.cuDeviceComputeCapability(
|
||||||
|
ctypes.byref(self.cc_major), ctypes.byref(self.cc_minor), self.device
|
||||||
|
),
|
||||||
|
)
|
||||||
|
tmp_ccs.append(f"{self.cc_major.value}.{self.cc_minor.value}")
|
||||||
|
self.ccs = sorted(tmp_ccs, reverse=True)
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
extract factors the build is dependent on:
|
extract factors the build is dependent on:
|
||||||
[X] compute capability
|
[X] compute capability
|
||||||
[ ] TODO: Q - What if we have multiple GPUs of different makes?
|
[ ] TODO: Q - What if we have multiple GPUs of different makes?
|
||||||
- CUDA version
|
- CUDA version
|
||||||
- Software:
|
- Software:
|
||||||
|
@ -23,7 +23,7 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Set, Union
|
from typing import Set, Union
|
||||||
|
|
||||||
from .utils import print_err, warn_of_missing_prerequisite, execute_and_return
|
from ..utils import print_err, warn_of_missing_prerequisite, execute_and_return
|
||||||
|
|
||||||
|
|
||||||
def check_cuda_result(cuda, result_val):
|
def check_cuda_result(cuda, result_val):
|
Loading…
Reference in New Issue
Block a user