From ac5550a0238286377ee3f58a85aeba1c40493e17 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 30 May 2023 19:06:59 -0700 Subject: [PATCH] Added changes for deployment. --- Makefile | 1 - csrc/kernels.cu | 10 +++++++--- deploy.sh | 11 ----------- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/Makefile b/Makefile index 1f2b281..5fa1f17 100644 --- a/Makefile +++ b/Makefile @@ -33,7 +33,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta -COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 11ad63f..ab12c37 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -16,15 +16,12 @@ #include #include -#include -#include #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 -using namespace nvcuda; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { @@ -3094,6 +3091,9 @@ template __device__ inline void vector_l #define WARPS 5 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; @@ -3294,11 +3294,14 @@ template __global__ void gemm_device(int M, if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif } template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; const int half_warp_id = threadIdx.x / 16; @@ -3459,6 +3462,7 @@ template __global__ void kgemm_4bit_inference(int M, i if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif } //#define ROWS 2 diff --git a/deploy.sh b/deploy.sh index 24d6cbf..a2257a2 100644 --- a/deploy.sh +++ b/deploy.sh @@ -139,17 +139,6 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then fi -make clean -export CUDA_HOME=$BASE_PATH/cuda-10.2 -make cuda10x_nomatmul CUDA_VERSION=102 - -if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then - # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 - exit 64 -fi - - make clean export CUDA_HOME=$BASE_PATH/cuda-11.0 make cuda110_nomatmul CUDA_VERSION=110