|
|
|
@ -21,8 +21,8 @@
|
|
|
|
|
#include <hip/hip_runtime_api.h>
|
|
|
|
|
#include <hip/hip_fp16.h>
|
|
|
|
|
#include <hipblas/hipblas.h>
|
|
|
|
|
#include <hipblaslt/hipblaslt.h> //only supports gfx903
|
|
|
|
|
#include <hipsparse/hipsparse.h>
|
|
|
|
|
#include <hip/hip_bfloat16.h>
|
|
|
|
|
|
|
|
|
|
#define cudaPeekAtLastError hipPeekAtLastError
|
|
|
|
|
#define cudaMemset hipMemset
|
|
|
|
@ -36,7 +36,7 @@
|
|
|
|
|
#define cublasStatus_t hipblasStatus_t
|
|
|
|
|
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
|
|
|
|
#define cublasOperation_t hipblasOperation_t
|
|
|
|
|
#define cublasLtMatrixLayoutCreate hipblasLtMatrixLayoutCreate
|
|
|
|
|
#define cublasLtMatrixLayoutCreate hipblasMatrixLayoutCreate
|
|
|
|
|
#define cudaError_t hipError_t
|
|
|
|
|
#define cudaGetErrorString hipGetErrorString
|
|
|
|
|
#define cudaSuccess hipSuccess
|
|
|
|
@ -49,8 +49,8 @@
|
|
|
|
|
#define cusparseHandle_t hipsparseHandle_t
|
|
|
|
|
#define cusparseCreate hipsparseCreate
|
|
|
|
|
#define __nv_bfloat16 hip_bfloat16
|
|
|
|
|
#define cublasLtHandle_t hipblasLtHandle_t
|
|
|
|
|
#define cublasLtCreate hipblasLtCreate
|
|
|
|
|
#define cublasLtHandle_t hipblasHandle_t
|
|
|
|
|
#define cublasLtCreate hipblasCreate
|
|
|
|
|
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
|
|
|
|
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
|
|
|
|
|
|
|
|
|
|