diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 55425ee..6036cff 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -21,8 +21,8 @@ #include #include #include -#include //only supports gfx903 #include +#include #define cudaPeekAtLastError hipPeekAtLastError #define cudaMemset hipMemset @@ -36,7 +36,7 @@ #define cublasStatus_t hipblasStatus_t #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx #define cublasOperation_t hipblasOperation_t -#define cublasLtMatrixLayoutCreate hipblasLtMatrixLayoutCreate +#define cublasLtMatrixLayoutCreate hipblasMatrixLayoutCreate #define cudaError_t hipError_t #define cudaGetErrorString hipGetErrorString #define cudaSuccess hipSuccess @@ -49,8 +49,8 @@ #define cusparseHandle_t hipsparseHandle_t #define cusparseCreate hipsparseCreate #define __nv_bfloat16 hip_bfloat16 -#define cublasLtHandle_t hipblasLtHandle_t -#define cublasLtCreate hipblasLtCreate +#define cublasLtHandle_t hipblasHandle_t +#define cublasLtCreate hipblasCreate #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 0a63a95..19b474f 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -7,8 +7,8 @@ #include #ifdef BITS_AND_BYTES_USE_ROCM -#include -#define cublasLtHandle_t hipblasLtHandle_t +#include +#define cublasLtHandle_t hipblasHandle_t #define cudaMemAttachHost hipMemAttachHost #define cudaMallocManaged hipMallocManaged #define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess