30 lines
791 B
Python
30 lines
791 B
Python
#!/usr/bin/env python3
|
|
import os
|
|
import torch
|
|
|
|
from setuptools import setup, find_packages
|
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
|
|
|
def cxx_std_flag(torch_version):
    """Return the ``-std=`` flag for the C++ standard *torch_version* needs.

    PyTorch's extension headers require C++14 from release 1.5 and C++17
    from release 2.1. Pinning ``-std=c++11`` makes the build fail against
    those releases, while older releases still build with C++11.

    Args:
        torch_version: A PyTorch version string such as ``'1.13.1+cu117'``
            or ``'2.4.0a0+git1234'``; only the leading ``major.minor``
            digits are inspected.

    Returns:
        One of ``'-std=c++11'``, ``'-std=c++14'`` or ``'-std=c++17'``.
        A version string that cannot be parsed yields the newest flag.
    """
    import re  # local import: only this helper needs it

    match = re.match(r'(\d+)\.(\d+)', torch_version)
    if match is None:
        return '-std=c++17'
    major_minor = (int(match.group(1)), int(match.group(2)))
    if major_minor >= (2, 1):
        return '-std=c++17'
    if major_minor >= (1, 5):
        return '-std=c++14'
    return '-std=c++11'


# Flags for the host C++ compiler (the 'cxx' entry of extra_compile_args),
# matched to the C++ standard the installed PyTorch headers require.
cxx_args = [cxx_std_flag(torch.__version__)]


# Compute capabilities compiled ahead of time (SASS) when the user has not
# chosen a set; the last entry is additionally embedded as PTX.
DEFAULT_CUDA_ARCHS = ['50', '52', '60', '61', '70']


def nvcc_arch_args(environ=None):
    """Return the nvcc ``-gencode`` flags for the correlation kernel.

    PyTorch's extension builder skips its own architecture selection
    (including ``TORCH_CUDA_ARCH_LIST``) as soon as explicit arch flags are
    passed to nvcc. To keep that variable usable, an empty list is
    returned when it is set, so PyTorch derives the flags from it.

    Otherwise one ``code=sm_XX`` target is emitted per entry of
    ``DEFAULT_CUDA_ARCHS``. It is followed by PTX for the newest entry, so
    that GPUs newer than that one can still JIT-compile the kernel.

    Args:
        environ: Mapping to read ``TORCH_CUDA_ARCH_LIST`` from; defaults to
            ``os.environ``.

    Returns:
        A new list of nvcc command-line arguments.
    """
    env = os.environ if environ is None else environ
    if env.get('TORCH_CUDA_ARCH_LIST'):
        return []

    args = []
    for arch in DEFAULT_CUDA_ARCHS:
        args += ['-gencode', 'arch=compute_{0},code=sm_{0}'.format(arch)]
    newest = DEFAULT_CUDA_ARCHS[-1]
    args += ['-gencode', 'arch=compute_{0},code=compute_{0}'.format(newest)]
    return args


nvcc_args = nvcc_arch_args()


# The compiled Python module and the distribution share one name.
extension_name = 'correlation_cuda'

# Host (.cc) and device (.cu) sources, each compiled with its own flag set.
correlation_extension = CUDAExtension(
    extension_name,
    ['correlation_cuda.cc', 'correlation_cuda_kernel.cu'],
    extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args},
)

# BuildExtension replaces the default build_ext command so the mixed
# C++/CUDA sources above are compiled with the PyTorch toolchain settings.
setup(
    name=extension_name,
    ext_modules=[correlation_extension],
    cmdclass={'build_ext': BuildExtension},
)