fix: Replace libcudart with pytorch api
This commit is contained in:
parent
659a7dfc71
commit
97b2567ada
|
@ -326,31 +326,10 @@ def get_cuda_lib_handle():
|
|||
|
||||
|
||||
def get_compute_capabilities(cuda):
|
||||
"""
|
||||
1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||
init_device -> init variables -> call function by reference
|
||||
2. call extern C function to determine CC
|
||||
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
||||
3. Check for CUDA errors
|
||||
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
|
||||
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
||||
"""
|
||||
|
||||
nGpus = ct.c_int()
|
||||
cc_major = ct.c_int()
|
||||
cc_minor = ct.c_int()
|
||||
|
||||
device = ct.c_int()
|
||||
|
||||
check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus)))
|
||||
ccs = []
|
||||
for i in range(nGpus.value):
|
||||
check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i))
|
||||
ref_major = ct.byref(cc_major)
|
||||
ref_minor = ct.byref(cc_minor)
|
||||
# 2. call extern C function to determine CC
|
||||
check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
|
||||
ccs.append(f"{cc_major.value}.{cc_minor.value}")
|
||||
# for i in range(torch.cuda.device_count()):
|
||||
# device = torch.cuda.device(i)
|
||||
ccs.append(torch.version.cuda)
|
||||
|
||||
return ccs
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user