From 1b52f4243f94cd1b81dd1cad5a9465d9d7add858 Mon Sep 17 00:00:00 2001
From: broncotc
Date: Thu, 24 Nov 2022 05:15:08 +0000
Subject: [PATCH] fixed: works on gfx1030, saves RAM

---
 Makefile                   |  13 ++-
 bitsandbytes/__init__.py   |   1 -
 bitsandbytes/__main__.py   |  16 ----
 bitsandbytes/cextension.py |   2 +-
 csrc/kernels.cu            | 192 +++++++++++++++++++------------------
 csrc/ops.cu                |  42 ++++----
 csrc/ops.cuh               |   1 -
 csrc/pythonInterface.c     |   2 +-
 8 files changed, 139 insertions(+), 130 deletions(-)

diff --git a/Makefile b/Makefile
index 6c8a8c3..817310e 100644
--- a/Makefile
+++ b/Makefile
@@ -20,7 +20,7 @@ CSRC := $(ROOT_DIR)/csrc
 BUILD_DIR:= $(ROOT_DIR)/build
 
 FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
-FILES_HIP := $(CSRC)/ops.cu $(CSRC)/kernels.cu
+# FILES_HIP := $(CSRC)/ops.cu $(CSRC)/kernels.cu
 FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
 
 INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
@@ -98,6 +98,17 @@ cuda11x: $(BUILD_DIR) env
 cpuonly: $(BUILD_DIR) env
 	$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so
 
+
+HIP_INCLUDE := -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include
+# -I /opt/rocm-5.3.0/hipcub/include
+HIP_LIB := -L/opt/rocm-5.3.0/lib -L/opt/rocm-5.3.0/llvm/bin/../lib/clang/15.0.0/lib/linux -L/usr/lib/gcc/x86_64-linux-gnu/11 -L/usr/lib/gcc/x86_64-linux-gnu/11/../../../../lib64 -L/lib/x86_64-linux-gnu -L/lib/../lib64 -L/usr/lib/x86_64-linux-gnu -L/usr/lib/../lib64 -L/lib -L/usr/lib -lgcc_s -lgcc -lpthread -lm -lrt -lamdhip64 -lhipblas -lhipsparse -lclang_rt.builtins-x86_64 -lstdc++ -lm -lgcc_s -lgcc -lc -lgcc_s -lgcc
+
+hip: $(BUILD_DIR)
+	/usr/bin/hipcc -std=c++14 -c -fPIC --amdgpu-target=gfx1030 $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -D NO_CUBLASLT $(CSRC)/ops.cu
+	/usr/bin/hipcc -std=c++14 -c -fPIC --amdgpu-target=gfx1030 $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -D NO_CUBLASLT $(CSRC)/kernels.cu
+	# /usr/bin/hipcc -fPIC -static $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.so
+	$(GPP) -std=c++14 -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -shared -fPIC -I /opt/rocm/include $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nocublaslt.so
+
 env:
	@echo "ENVIRONMENT"
	@echo "============================"
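Note on the build: the new `hip` target compiles the two device-code files with hipcc for gfx1030, then links the shared library with the host compiler ($(GPP)). Passing -D__HIP_PLATFORM_AMD__ plus the ROCm include path is what lets plain g++ digest the HIP headers in the host-only files (common.cpp, cpu_ops.cpp, pythonInterface.c); only the kernels need hipcc. A minimal standalone sketch of that host-compiler mechanism (file name and flags are illustrative, not part of the patch):

    // hiphost_check.cpp -- hypothetical demo, not in the repo.
    // Build: g++ -D__HIP_PLATFORM_AMD__ -I/opt/rocm/include hiphost_check.cpp \
    //            -L/opt/rocm/lib -lamdhip64 -o hiphost_check
    #include <hip/hip_runtime_api.h>  // host-only HIP runtime API
    #include <cstdio>

    int main() {
        int n = 0;
        if (hipGetDeviceCount(&n) != hipSuccess) return 1;  // no HIP runtime/device
        std::printf("HIP devices: %d\n", n);
        return 0;
    }

After `make hip`, the library lands at ./bitsandbytes/libbitsandbytes_hip_nocublaslt.so, which matches the binary_name that cextension.py loads below.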
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 6d1177f..b968204 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -12,7 +12,6 @@ from .autograd._functions import (
 )
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
-from . import cuda_setup, utils
 
 if COMPILED_WITH_CUDA:
     from .optim import adam
diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py
index 175a30e..3ebf574 100644
--- a/bitsandbytes/__main__.py
+++ b/bitsandbytes/__main__.py
@@ -31,23 +31,7 @@ print()
 from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
-from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
-from .cuda_setup.env_vars import to_be_ignored
-print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
-for k, v in os.environ.items():
-    if "/" in v and not to_be_ignored(k, v):
-        print(f"'{k}': '{v}'")
-print_header("")
-
-print(
-    "\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n"
-)
-
-print_header("OTHER")
-print(f"{COMPILED_WITH_CUDA = }")
-cuda = get_cuda_lib_handle()
-print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}")
 print_header("")
 print_header("DEBUG INFO END")
 print_header("")
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 3f4273d..2b688c1 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -27,7 +27,7 @@ class CUDASetup(object):
         self.initialized = True
         self.cuda_setup_log = []
 
-        binary_name = "libbitsandbytes_hip.so"
+        binary_name = "libbitsandbytes_hip_nocublaslt.so"
         package_dir = Path(__file__).parent
         binary_path = package_dir / binary_name
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index c29fa20..2bd7782 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -6,14 +6,14 @@
 #include
 #include
-#include
-#include
-#include
-#include
-#include
-#include
 #include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
 
 #define HLF_MAX 65504
 #define TH 1024
@@ -21,29 +21,29 @@
 #define NUM_BLOCK 4096
 
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
-__device__ float atomicMax(float* address, float val) {
-  int* address_as_i = reinterpret_cast<int*>(address);
-  int old = *address_as_i, assumed;
-  do {
-    assumed = old;
-    old = atomicCAS(
-        reinterpret_cast<int*>(address), assumed,
-        __float_as_int(fmaxf(val, __int_as_float(assumed))));
-  } while (assumed != old);
-  return __int_as_float(old);
-}
+// __device__ float atomicMax(float* address, float val) {
+//   int* address_as_i = reinterpret_cast<int*>(address);
+//   int old = *address_as_i, assumed;
+//   do {
+//     assumed = old;
+//     old = atomicCAS(
+//         reinterpret_cast<int*>(address), assumed,
+//         __float_as_int(fmaxf(val, __int_as_float(assumed))));
+//   } while (assumed != old);
+//   return __int_as_float(old);
+// }
 
-__device__ float atomicMin(float* address, float val) {
-  int* address_as_i = reinterpret_cast<int*>(address);
-  int old = *address_as_i, assumed;
-  do {
-    assumed = old;
-    old = atomicCAS(
-        reinterpret_cast<int*>(address), assumed,
-        __float_as_int(fminf(val, __int_as_float(assumed))));
-  } while (assumed != old);
-  return __int_as_float(old);
-}
+// __device__ float atomicMin(float* address, float val) {
+//   int* address_as_i = reinterpret_cast<int*>(address);
+//   int old = *address_as_i, assumed;
+//   do {
+//     assumed = old;
+//     old = atomicCAS(
+//         reinterpret_cast<int*>(address), assumed,
+//         __float_as_int(fminf(val, __int_as_float(assumed))));
+//   } while (assumed != old);
+//   return __int_as_float(old);
+// }
 
 template
 __device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
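The hand-rolled CAS loops for float atomicMax/atomicMin are commented out, presumably because ROCm's HIP headers already ship float overloads of these atomics, so the redefinitions fail to compile under hipcc. If a CUDA build ever needs the helper again, a guarded sketch (renamed to avoid any clash; the name is mine, not the patch's):

    // Guarded float atomic-max, a sketch assuming HIP provides its own
    // float atomicMax. Same CAS idiom as the code commented out above.
    #ifndef __HIP_PLATFORM_AMD__
    __device__ float atomicMaxFloat(float* address, float val) {
        // Reinterpret the float as int so atomicCAS can compare-and-swap it.
        int* address_as_i = reinterpret_cast<int*>(address);
        int old = *address_as_i, assumed;
        do {
            assumed = old;
            old = atomicCAS(address_as_i, assumed,
                            __float_as_int(fmaxf(val, __int_as_float(assumed))));
        } while (assumed != old);
        return __int_as_float(old);
    }
    #endif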
@@ -236,7 +236,7 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou
 {
   typedef hipcub::WarpReduce WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage;
-  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadT;
   __shared__ typename LoadT::TempStorage loadt;
 
   const int warp_idx = threadIdx.x/32;
@@ -284,8 +284,8 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou
   for(int i = 0; i < 8; i++)
   {
     // 3. do warp reduction + broadcast back
-    warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max());
-    warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff);
+    warp_max = WarpReduce(temp_storage).Reduce(max1, hipcub::Max());
+    warp_max = hipcub::ShuffleIndex<32>(warp_max, 0, 0xffffffff);
 
     // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
     if(warp_max == max1)
@@ -299,7 +299,9 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou
       max2 = -64000.0f;
     }
 
-    __syncwarp();
+    // __syncwarp();
+    __syncthreads();
+
   }
 
   if(threadIdx.x % 32 < 8)
@@ -326,8 +328,8 @@ __global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const f
 
   T vals[NUM_ESTIMATE];
 
-  typedef hipcub::BlockRadixSort BlockRadixSort;
-  typedef hipcub::BlockLoad LoadFloat;
+  typedef hipcub::BlockRadixSort BlockRadixSort;
+  typedef hipcub::BlockLoad LoadFloat;
 
   __shared__ union {
       typename LoadFloat::TempStorage loadf;
@@ -393,8 +395,8 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
   unsigned char qvals[NUM];
   //const int lane_id = threadIdx.x % 2;
 
-  typedef hipcub::BlockLoad LoadFloat;
-  typedef hipcub::BlockStore StoreChar;
+  typedef hipcub::BlockLoad LoadFloat;
+  typedef hipcub::BlockStore StoreChar;
 
   __shared__ typename LoadFloat::TempStorage loadf;
   __shared__ typename StoreChar::TempStorage storec;
@@ -444,10 +446,10 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
   float local_abs_max = 0.0f;
   int local_rand_idx = 0;
 
-  typedef hipcub::BlockLoad LoadT;
-  typedef hipcub::BlockStore StoreChar;
+  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockStore StoreChar;
   typedef hipcub::BlockReduce BlockReduce;
-  typedef hipcub::BlockLoad LoadFloat;
+  typedef hipcub::BlockLoad LoadFloat;
 
   __shared__ typename LoadT::TempStorage loadt;
   __shared__ typename LoadFloat::TempStorage loadf;
@@ -475,7 +477,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
     for(int j = 0; j < NUM_PER_TH; j++)
        local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
 
-    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
+    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items);
 
     if(threadIdx.x == 0)
       smem_absmax_value[0] = local_abs_max;
@@ -487,7 +489,9 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
     else
       local_abs_max = smem_absmax_value[0];
 
-    __syncwarp();
+    // __syncwarp();
+    __syncthreads();
+
 
     local_abs_max = 1.0f/local_abs_max;
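Every `__syncwarp()` in the file becomes `__syncthreads()`: HIP on AMD hardware (64-wide wavefronts) does not offer `__syncwarp()`, and a block-wide barrier is a strictly stronger substitute, at the cost of stalling the whole block. Note that `__syncthreads()` must be reached by every thread of the block, which is why the barrier is hoisted out of divergent code (see the extra closing brace added in kCompressMax above). A portability shim along these lines (a sketch, not in the patch) would avoid editing each call site:

    // SYNC_WARP(): warp-level sync on CUDA, block-wide barrier on AMD HIP.
    // Correct wherever the whole block reaches the barrier; slower than a
    // true warp-level sync.
    #if defined(__HIP_PLATFORM_AMD__)
      #define SYNC_WARP() __syncthreads()
    #else
      #define SYNC_WARP() __syncwarp()
    #endif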
@@ -523,8 +527,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
   unsigned char qvals[NUM_PER_TH];
   float local_abs_max = -FLT_MAX;
 
-  typedef hipcub::BlockLoad LoadChar;
-  typedef hipcub::BlockStore StoreT;
+  typedef hipcub::BlockLoad LoadChar;
+  typedef hipcub::BlockStore StoreT;
 
   __shared__ typename LoadChar::TempStorage loadchar;
   __shared__ typename StoreT::TempStorage storet;
@@ -594,8 +598,8 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
   const float correction1 = 1.0f/(1.0f - powf(beta1, step));
   const float correction2 = 1.0f/(1.0f - powf(beta2, step));
 
-  typedef hipcub::BlockLoad Load;
-  typedef hipcub::BlockLoad LoadFloat;
+  typedef hipcub::BlockLoad Load;
+  typedef hipcub::BlockLoad LoadFloat;
   typedef hipcub::BlockReduce BlockReduce;
 
   __shared__ union {
@@ -645,7 +649,9 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
 
       if(threadIdx.x == 0)
        atomicAdd(&unorm[0], s1_vals[0]);
 
-      __syncwarp();
+      // __syncwarp();
+      __syncthreads();
+
   }
 }
@@ -683,11 +689,11 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
   }
   else{ update_scale = 1.0f; }
 
-  typedef hipcub::BlockLoad Load;
-  typedef hipcub::BlockStore Store;
+  typedef hipcub::BlockLoad Load;
+  typedef hipcub::BlockStore Store;
 
-  typedef hipcub::BlockLoad LoadFloat;
-  typedef hipcub::BlockStore StoreFloat;
+  typedef hipcub::BlockLoad LoadFloat;
+  typedef hipcub::BlockStore StoreFloat;
 
   __shared__ union {
       typename Load::TempStorage load;
@@ -757,8 +763,8 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
 
   float s1_vals[NUM_VALS];
 
-  typedef hipcub::BlockLoad Load;
-  typedef hipcub::BlockLoad LoadFloat;
+  typedef hipcub::BlockLoad Load;
+  typedef hipcub::BlockLoad LoadFloat;
   typedef hipcub::BlockReduce BlockReduce;
 
   __shared__ union {
@@ -815,7 +821,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
       if(threadIdx.x == 0)
        atomicAdd(&unorm[0], s1_vals[0]);
 
-      __syncwarp();
+      // __syncwarp();
+      __syncthreads();
+
   }
 }
@@ -845,11 +853,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
 
   float s1_vals[NUM_PER_THREAD];
 
-  typedef hipcub::BlockLoad Load;
-  typedef hipcub::BlockStore Store;
+  typedef hipcub::BlockLoad Load;
+  typedef hipcub::BlockStore Store;
 
-  typedef hipcub::BlockLoad LoadFloat;
-  typedef hipcub::BlockStore StoreFloat;
+  typedef hipcub::BlockLoad LoadFloat;
+  typedef hipcub::BlockStore StoreFloat;
 
   __shared__ union {
       typename Load::TempStorage load;
@@ -941,8 +949,8 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
   unsigned char m_c1[NUM8BIT];
   unsigned char r_c2[NUM8BIT];
 
-  typedef hipcub::BlockLoad LoadT;
-  typedef hipcub::BlockLoad LoadUInt8;
+  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadUInt8;
   typedef hipcub::BlockReduce BlockReduce;
 
@@ -1010,13 +1018,13 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
   }
 
   __syncthreads();
-  local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
+  local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items);
   __syncthreads();
-  local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items);
+  local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items);
   if(unorm != NULL)
   {
     __syncthreads();
-    local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
+    local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items);
   }
 
   if(threadIdx.x == 0)
@@ -1070,11 +1078,11 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
   unsigned char c2s[NUM_PER_THREAD2];
   T p_vals[NUM_PER_THREAD2];
   T g_vals[NUM_PER_THREAD2];
-  typedef hipcub::BlockLoad LoadT;
-  typedef hipcub::BlockLoad LoadChar;
+  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadChar;
 
-  typedef hipcub::BlockStore StoreChar;
-  typedef hipcub::BlockStore StoreT;
+  typedef hipcub::BlockStore StoreChar;
+  typedef hipcub::BlockStore StoreT;
 
   __shared__ float smem_quantiles1[256];
   __shared__ float smem_quantiles2[256];
@@ -1178,8 +1186,8 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
   T g_vals[NUM8BIT];
   unsigned char m_c1[NUM8BIT];
 
-  typedef hipcub::BlockLoad LoadT;
-  typedef hipcub::BlockLoad LoadUInt8;
+  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadUInt8;
   typedef hipcub::BlockReduce BlockReduce;
 
@@ -1231,12 +1239,12 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
   }
 
   __syncthreads();
-  local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items);
+  local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items);
   if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); }
   if(unorm != NULL)
   {
     __syncthreads();
-    local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items);
+    local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items);
     if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); }
   }
 
@@ -1273,11 +1281,11 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
   unsigned char c1s[NUM_PER_THREAD2];
   T p_vals[NUM_PER_THREAD2];
   T g_vals[NUM_PER_THREAD2];
-  typedef hipcub::BlockLoad LoadT;
-  typedef hipcub::BlockLoad LoadChar;
+  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadChar;
 
-  typedef hipcub::BlockStore StoreChar;
-  typedef hipcub::BlockStore StoreT;
+  typedef hipcub::BlockStore StoreChar;
+  typedef hipcub::BlockStore StoreT;
 
   __shared__ float smem_quantiles1[256];
 
@@ -1356,7 +1364,7 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
   int valid_items = 0;
 
   typedef hipcub::BlockReduce BlockReduce;
-  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadT;
 
   __shared__ typename BlockReduce::TempStorage reduce;
 
@@ -1428,11 +1436,11 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
   unsigned char c1s[N_PER_TH];
   unsigned char c2s[N_PER_TH];
   T g_vals[N_PER_TH];
-  typedef hipcub::BlockLoad LoadT;
-  typedef hipcub::BlockLoad LoadChar;
+  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadChar;
 
-  typedef hipcub::BlockStore StoreChar;
-  typedef hipcub::BlockStore StoreT;
+  typedef hipcub::BlockStore StoreChar;
+  typedef hipcub::BlockStore StoreT;
 
   __shared__ float smem_quantiles1[LANES][257];
   __shared__ float smem_quantiles2[LANES][257];
@@ -1506,8 +1514,8 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
 
     // reduce: 2.51/1.60 -> 2.67/1.69
-    new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
-    new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());
+    new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max());
+    new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max());
 
     if(threadIdx.x == 0)
     {
@@ -1601,11 +1609,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
   T g_vals[N_PER_TH];
   T p_vals[N_PER_TH];
 
-  typedef hipcub::BlockLoad LoadT;
-  typedef hipcub::BlockLoad LoadChar;
+  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadChar;
 
-  typedef hipcub::BlockStore StoreChar;
-  typedef hipcub::BlockStore StoreT;
+  typedef hipcub::BlockStore StoreChar;
+  typedef hipcub::BlockStore StoreT;
 
   __shared__ float smem_quantiles1[LANES][257];
   typedef hipcub::BlockReduce BlockReduce1;
@@ -1680,7 +1688,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
 
     // reduce: 2.51/1.60 -> 2.67/1.69
-    new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
+    new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max());
 
     if(threadIdx.x == 0)
       smem_exchange1[0] = new_local_abs_max1;
@@ -1758,7 +1766,7 @@ template
-  typedef hipcub::BlockLoad LoadT;
+  typedef hipcub::BlockLoad LoadT;
   typedef hipcub::BlockReduce BlockRowReduce;
   typedef hipcub::BlockReduce BlockRowSum;
   typedef hipcub::BlockExchange BlockExchange;
@@ -1839,7 +1847,7 @@ template__global__ void kd
   float local_rowStats[ITEMS_PER_THREAD];
   __shared__ float smem_rowStats[SUBTILE_ROWS];
 
-  typedef hipcub::BlockLoad LoadInt32;
+  typedef hipcub::BlockLoad LoadInt32;
   typedef hipcub::BlockExchange ExchangeInt32;
   __shared__ typename LoadInt32::TempStorage loadint32;
   __shared__ typename ExchangeInt32::TempStorage exchangeint32;
@@ -2035,9 +2043,9 @@ template
-  typedef hipcub::BlockLoad LoadHalf;
+  typedef hipcub::BlockLoad LoadHalf;
   __shared__ typename LoadHalf::TempStorage loadhalf;
-  typedef hipcub::BlockStore StoreInt8;
+  typedef hipcub::BlockStore StoreInt8;
   __shared__ typename StoreInt8::TempStorage storeint8;
 
   __shared__ float smem_row_stats[TILE_ROWS];
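The rest of the kernels.cu changes are mechanical: every cub::Max(), cub::Sum() and cub::ShuffleIndex becomes its hipcub:: twin, hipCUB being the header-compatible wrapper over rocPRIM on AMD. The template arguments of the BlockLoad/BlockStore/BlockReduce typedefs did not survive in this copy of the patch, so here is an illustrative (not verbatim) reconstruction of the pattern these kernels use:

    // Block-wide absolute-max reduction in the hipcub style used above.
    // Template arguments are illustrative; the patch's own were stripped.
    #include <hipcub/hipcub.hpp>

    template <int BLOCK_SIZE>
    __global__ void kBlockAbsMax(const float* in, float* out)
    {
        typedef hipcub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
        __shared__ typename BlockReduce::TempStorage temp_storage;

        float v = fabsf(in[blockIdx.x * BLOCK_SIZE + threadIdx.x]);
        float block_max = BlockReduce(temp_storage).Reduce(v, hipcub::Max());

        if (threadIdx.x == 0)
            out[blockIdx.x] = block_max;  // one result per block
    }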
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 9ce07be..5eaa5fb 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -7,7 +7,7 @@
 #include
 #include "ops.cuh"
 #include "kernels.cuh"
-#include
+// #include
 #include
 // #include
 #include
@@ -223,24 +223,32 @@ template void percentileClipping(T * g, float *gnorm_vec, int step,
 
 void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
 {
-	const int falpha = 1;
-	const int fbeta = 0;
-	const void * alpha = &falpha;
-	const void * beta = &fbeta;
-	hipblasStatus_t status;
+	cout << "" << endl;
+	cout << "=============================================" << endl;
+	cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
+	cout << "=============================================" << endl;
+	cout << "" << endl;
+	assert(false);
 
-	status = hipblasGemmEx(context->m_handle,
-					transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-					transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-					m, n, k,
-					alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
-					C, HIPBLAS_R_32I, ldc,
-					HIPBLAS_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+	return ;
+	// const int falpha = 1;
+	// const int fbeta = 0;
+	// const void * alpha = &falpha;
+	// const void * beta = &fbeta;
+	// hipblasStatus_t status;
 
-	if (status != HIPBLAS_STATUS_SUCCESS)
-	{
-		std::cout << "CUBLAS ERROR: Status " << status << std::endl;
-	}
+	// status = hipblasGemmEx(context->m_handle,
+	//					transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
+	//					transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
+	//					m, n, k,
+	//					alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
+	//					C, HIPBLAS_R_32I, ldc,
+	//					HIPBLAS_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+
+	// if (status != HIPBLAS_STATUS_SUCCESS)
+	// {
+	//	std::cout << "CUBLAS ERROR: Status " << status << std::endl;
+	// }
 }
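gemmex() is the int8 matmul entry point; on this stack there is no usable 8-bit GemmEx path (the build defines NO_CUBLASLT, and gfx1030 has no tensor-op equivalent), so the port now fails fast with a clear message instead of returning garbage. One detail worth flagging: the commented-out call still names CUBLAS_GEMM_DEFAULT_TENSOR_OP, which hipBLAS does not define; if the call is ever revived it would need hipBLAS's own algorithm enum. A hedged sketch of what that body might look like on a stack that does support int8 GemmEx (same parameters as gemmex above; not tested on gfx1030):

    // Hypothetical revival of the int8 GEMM body, assuming hipBLAS int8
    // support. HIPBLAS_GEMM_DEFAULT replaces CUBLAS_GEMM_DEFAULT_TENSOR_OP.
    const int falpha = 1, fbeta = 0;
    hipblasStatus_t status = hipblasGemmEx((hipblasHandle_t) context->m_handle,
                    transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    m, n, k,
                    &falpha, A, HIPBLAS_R_8I, lda,
                    B, HIPBLAS_R_8I, ldb,
                    &fbeta, C, HIPBLAS_R_32I, ldc,
                    HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
    if (status != HIPBLAS_STATUS_SUCCESS)
        std::cout << "hipBLAS ERROR: Status " << status << std::endl;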
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index eb706d7..2ba039c 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -11,7 +11,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 5bac30e..ae5e20c 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -275,7 +275,7 @@ extern "C"
	{ transform_row2ampereT(A, out, rows, cols); }
 
	void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
-	{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
+	{ spmm_coo((hipsparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
 
	void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
	{ spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
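The last fix casts the opaque handle to hipsparseHandle_t instead of cusparseHandle_t before handing it to spmm_coo. The wrapper stores the library handle untyped, so only the cast at the call site has to know which backend is real; a minimal sketch of that pattern (assumed layout; the actual ContextCusparse lives in ops.cuh):

    // Assumed shape of the opaque-handle wrapper this cast relies on.
    #include <hipsparse.h>

    struct ContextCusparse {
        void* m_handle;  // stored untyped; cast back at each call site
        ContextCusparse() {
            hipsparseHandle_t h = nullptr;
            hipsparseCreate(&h);   // create the hipSPARSE library context
            m_handle = (void*)h;
        }
    };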