import ctypes
from dataclasses import dataclass, field
from typing import List


@dataclass
class CudaLibVals:
    """Query the CUDA driver for the compute capability of each GPU.

    On construction the CUDA driver library is loaded through ``ctypes``,
    the driver API is initialised with ``cuInit(0)``, and ``ccs`` is filled
    with one ``"major.minor"`` string per device reported by
    ``cuDeviceGetCount``, highest capability first.

    Code bits taken from
    https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549

    Raises:
        OSError: if none of the candidate driver libraries can be loaded.
        RuntimeError: if a driver API call returns a non-success code.
    """

    # ctypes out-parameters handed to the driver by reference.  They use
    # ``default_factory`` so that every instance owns its own buffers; a
    # plain ``default=ctypes.c_int()`` would be one object shared (and
    # overwritten) by all instances.
    nGpus: ctypes.c_int = field(default_factory=ctypes.c_int)
    cc_major: ctypes.c_int = field(default_factory=ctypes.c_int)
    cc_minor: ctypes.c_int = field(default_factory=ctypes.c_int)
    device: ctypes.c_int = field(default_factory=ctypes.c_int)
    error_str: ctypes.c_char_p = field(default_factory=ctypes.c_char_p)
    # Handle of the loaded driver library; set by ``_load_cuda_lib``.
    cuda: ctypes.CDLL = field(init=False, repr=False)
    # Compute capabilities as "major.minor", sorted highest first; set by
    # ``__post_init__``.
    ccs: List[str] = field(init=False)

    # Un-annotated class attributes are not dataclass fields.
    _CUDA_SUCCESS = 0  # constant taken from cuda.h
    # Candidate file names for the driver library, tried in order.  The
    # versioned name covers systems without the unversioned symlink.
    _LIBNAMES = ("libcuda.so", "libcuda.so.1")

    def _initialize_driver_API(self):
        """Initialise the CUDA driver API; must precede any other call."""
        self.check_cuda_result(self.cuda.cuInit(0))

    def _load_cuda_lib(self):
        """Load the CUDA driver library and store the handle in ``self.cuda``.

        Each name in ``_LIBNAMES`` is tried in turn; the first one that
        loads wins.

        Raises:
            OSError: if none of the candidate names could be loaded.
        """
        for libname in self._LIBNAMES:
            try:
                self.cuda = ctypes.CDLL(libname)
            except OSError:
                continue
            else:
                break
        else:
            raise OSError(
                "could not load any of: " + " ".join(self._LIBNAMES)
            )

    def call_cuda_func(self, function_obj, **kwargs):
        """Placeholder for a generic driver-call wrapper.

        Currently does nothing with its arguments and returns ``None``.
        """
        # TODO: implement (call ``function_obj`` and check its result).
        return None

    def check_cuda_result(self, cuda_lib_call_return_value):
        """Return the driver result code unchanged, or raise on failure.

        Args:
            cuda_lib_call_return_value: result code returned by a CUDA
                driver API function.

        Returns:
            The same code when it equals ``CUDA_SUCCESS`` (0).

        Raises:
            RuntimeError: if the code signals an error.
        """
        return self._error_handle(cuda_lib_call_return_value)

    def _error_handle(self, cuda_lib_call_return_value):
        """Translate a failing driver result code into an exception.

        On failure the driver is asked for a description through
        ``cuGetErrorString`` (stored in ``self.error_str``) and that text
        is included in the raised error.

        Returns:
            ``cuda_lib_call_return_value`` when it equals ``CUDA_SUCCESS``.

        Raises:
            RuntimeError: if the code is not ``CUDA_SUCCESS``.
        """
        if cuda_lib_call_return_value == self._CUDA_SUCCESS:
            return cuda_lib_call_return_value

        lookup_status = self.cuda.cuGetErrorString(
            cuda_lib_call_return_value,
            ctypes.byref(self.error_str),
        )
        # The lookup itself can fail for unrecognised codes, in which case
        # the pointer is NULL and ``.value`` is None.
        raw_text = None
        if lookup_status == self._CUDA_SUCCESS:
            raw_text = self.error_str.value
        detail = (
            raw_text.decode(errors="replace") if raw_text else "unknown error"
        )
        raise RuntimeError(
            f"CUDA exception! (error {cuda_lib_call_return_value}: {detail})"
        )

    @staticmethod
    def _format_compute_capabilities(capabilities):
        """Return "major.minor" strings for ``capabilities``, highest first.

        Sorting happens on the numeric ``(major, minor)`` pairs; sorting
        the formatted strings would rank "8.6" above "12.0".
        """
        ordered = sorted(capabilities, reverse=True)
        return [f"{major}.{minor}" for major, minor in ordered]

    def __post_init__(self):
        """Load the driver, initialise it and collect all capabilities.

        Calls the extern C functions that determine the compute capability
        (see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html).
        """
        self._load_cuda_lib()
        self._initialize_driver_API()
        self.check_cuda_result(
            self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus))
        )
        capabilities = []
        for gpu_index in range(self.nGpus.value):
            self.check_cuda_result(
                self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index)
            )
            self.check_cuda_result(
                self.cuda.cuDeviceComputeCapability(
                    ctypes.byref(self.cc_major),
                    ctypes.byref(self.cc_minor),
                    self.device,
                )
            )
            capabilities.append((self.cc_major.value, self.cc_minor.value))
        self.ccs = self._format_compute_capabilities(capabilities)