diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 1ab8aa2..1cc7374 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -4,6 +4,23 @@
 // LICENSE file in the root directory of this source tree.
 
 #include
+
+#ifdef BITS_AND_BYTES_USE_ROCM
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#define cub hipcub
+#define __syncwarp __syncthreads //TODO: HIP doesn't have this so just sync threads
+
+#else
+#include
+#include
 #include
 #include
 #include
@@ -11,18 +28,17 @@
 #include
 #include
 #include
-#include
+#endif
+
 #include
 #include
-#include
-
 
 #define HLF_MAX 65504
 #define TH 1024
 #define NUM 4
 #define NUM_BLOCK 4096
 
-
+#ifndef BITS_AND_BYTES_USE_ROCM
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 __device__ float atomicMax(float* address, float val) {
   int* address_as_i = reinterpret_cast<int*>(address);
@@ -47,6 +63,7 @@ __device__ float atomicMin(float* address, float val) {
   } while (assumed != old);
   return __int_as_float(old);
 }
+#endif
 
 __device__ float dDequantizeFP4(unsigned char val, float absmax)
 {
@@ -723,6 +740,19 @@ template
     LoadChar(loadchar).Load(&(A[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
 }
+ #endif
 }
 
 template
@@ -3956,6 +3987,7 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
 
 template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
 template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
 
+
 #define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
 template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 9776121..0606fd3 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -5,12 +5,17 @@
 
 #include
 #include
-#include
 #include
 #include
 #include
 #include
 
 
+#ifdef BITS_AND_BYTES_USE_ROCM
+//#include
+#else
+#include
+#endif
+
 using namespace BinSearch;
 using std::cout;
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index f37b3b3..a55ec24 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -12,16 +12,56 @@
 
 #include
 #include
 
+
+#ifdef BITS_AND_BYTES_USE_ROCM
+#include
+#include
+#include
+#include //only supports gfx903
+#include
+
+#define cudaPeekAtLastError hipPeekAtLastError
+#define cudaMemset hipMemset
+#define cublasGemmEx hipblasGemmEx
+#define cublasStatus_t hipblasStatus_t
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUDA_R_8I HIPBLAS_R_8I
+#define CUDA_R_32I HIPBLAS_R_32I
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasStatus_t hipblasStatus_t
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+#define cublasOperation_t hipblasOperation_t
+#define cublasLtMatrixLayoutCreate hipblasLtMatrixLayoutCreate
+#define cudaError_t hipError_t
+#define cudaGetErrorString hipGetErrorString
+#define cudaSuccess hipSuccess
+#define cusparseStatus_t hipsparseStatus_t
+#define CUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
+#define cublasStatus_t hipblasStatus_t
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasHandle_t hipblasHandle_t
+#define cublasCreate_v2 hipblasCreate
+#define cusparseHandle_t hipsparseHandle_t
+#define cusparseCreate hipsparseCreate
+#define __nv_bfloat16 hip_bfloat16
+#define cublasLtHandle_t hipblasLtHandle_t
+#define cublasLtCreate hipblasLtCreate
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
+
+#else
 #include
 #include
 #include
 #include
 #include
-#include
-#include
 #include
 #include
+#endif
+#include
+#include
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 865e4b6..0a63a95 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -5,9 +5,22 @@
 
 #if BUILD_CUDA
 #include
+
+#ifdef BITS_AND_BYTES_USE_ROCM
+#include
+#define cublasLtHandle_t hipblasLtHandle_t
+#define cudaMemAttachHost hipMemAttachHost
+#define cudaMallocManaged hipMallocManaged
+#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
+#define cudaMemPrefetchAsync hipMemPrefetchAsync
+#endif
+
 #endif
 #include
 
+
+
 // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
 // We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
 // maintain all that boilerplate
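
The pattern this patch applies across csrc/ is a preprocessor compatibility shim: when BITS_AND_BYTES_USE_ROCM is defined, CUDA identifiers are #define-aliased to their HIP/hipBLAS/hipSPARSE equivalents (and cub to hipcub), so the same .cu sources compile under ROCm without touching any call sites. Below is a minimal sketch of the idea, for illustration only: the zero_device_buffer() wrapper and the chosen runtime headers are assumptions, while the four #define aliases are taken verbatim from the csrc/ops.cuh hunk above.

    // Sketch of the CUDA->HIP aliasing shim; not part of the patch itself.
    #ifdef BITS_AND_BYTES_USE_ROCM
    #include <hip/hip_runtime.h>          // assumed ROCm runtime header
    #define cudaError_t        hipError_t
    #define cudaSuccess        hipSuccess
    #define cudaGetErrorString hipGetErrorString
    #define cudaMemset         hipMemset
    #else
    #include <cuda_runtime.h>
    #endif

    #include <stdio.h>

    // Call sites are written once against the CUDA names; under ROCm the
    // preprocessor rewrites each call into the corresponding HIP one.
    int zero_device_buffer(void *dev_ptr, size_t bytes)
    {
        cudaError_t err = cudaMemset(dev_ptr, 0, bytes);
        if (err != cudaSuccess) {
            fprintf(stderr, "cudaMemset: %s\n", cudaGetErrorString(err));
            return 1;
        }
        return 0;
    }

The aliasing is deliberately lossy in places, and the patch says so in its TODOs: __syncwarp is emulated with a full __syncthreads block barrier, and CUBLAS_GEMM_DEFAULT_TENSOR_OP falls back to plain HIPBLAS_GEMM_DEFAULT, so warp-level synchronization and tensor-op GEMM algorithm selection do not map one-to-one on ROCm.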