fix: Replace libcudart with pytorch api
This commit is contained in:
parent
659a7dfc71
commit
97b2567ada
|
@ -326,31 +326,10 @@ def get_cuda_lib_handle():
|
||||||
|
|
||||||
|
|
||||||
def get_compute_capabilities(cuda):
|
def get_compute_capabilities(cuda):
|
||||||
"""
|
|
||||||
1. find libcuda.so library (GPU driver) (/usr/lib)
|
|
||||||
init_device -> init variables -> call function by reference
|
|
||||||
2. call extern C function to determine CC
|
|
||||||
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
|
||||||
3. Check for CUDA errors
|
|
||||||
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
|
|
||||||
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
|
||||||
"""
|
|
||||||
|
|
||||||
nGpus = ct.c_int()
|
|
||||||
cc_major = ct.c_int()
|
|
||||||
cc_minor = ct.c_int()
|
|
||||||
|
|
||||||
device = ct.c_int()
|
|
||||||
|
|
||||||
check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus)))
|
|
||||||
ccs = []
|
ccs = []
|
||||||
for i in range(nGpus.value):
|
# for i in range(torch.cuda.device_count()):
|
||||||
check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i))
|
# device = torch.cuda.device(i)
|
||||||
ref_major = ct.byref(cc_major)
|
ccs.append(torch.version.cuda)
|
||||||
ref_minor = ct.byref(cc_minor)
|
|
||||||
# 2. call extern C function to determine CC
|
|
||||||
check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
|
|
||||||
ccs.append(f"{cc_major.value}.{cc_minor.value}")
|
|
||||||
|
|
||||||
return ccs
|
return ccs
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user