Add HIP-to-CUDA defines
collected by hipifying all files and then comparing with the original CUDA file
This commit is contained in:
parent
18e827d666
commit
d10197bc93
|
@ -4,6 +4,23 @@
|
|||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <kernels.cuh>
|
||||
|
||||
#ifdef BITS_AND_BYTES_USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hipcub/block/block_radix_sort.hpp>
|
||||
#include <hipcub/warp/warp_reduce.hpp>
|
||||
#include <hipcub/block/block_load.hpp>
|
||||
#include <hipcub/block/block_discontinuity.hpp>
|
||||
#include <hipcub/block/block_store.hpp>
|
||||
#include <hipcub/block/block_reduce.hpp>
|
||||
#include <hip/hip_math_constants.h>
|
||||
#define cub hipcub
|
||||
#define __syncwarp __syncthreads //TODO: HIP doesn't have this so just sync threads
|
||||
|
||||
#else
|
||||
#include <math_constants.h>
|
||||
#include <mma.h>
|
||||
#include <cub/block/block_radix_sort.cuh>
|
||||
#include <cub/warp/warp_reduce.cuh>
|
||||
#include <cub/block/block_load.cuh>
|
||||
|
@ -11,18 +28,17 @@
|
|||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#include <math_constants.h>
|
||||
#endif
|
||||
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <mma.h>
|
||||
|
||||
|
||||
#define HLF_MAX 65504
|
||||
#define TH 1024
|
||||
#define NUM 4
|
||||
#define NUM_BLOCK 4096
|
||||
|
||||
|
||||
#ifndef BITS_AND_BYTES_USE_ROCM
|
||||
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
|
||||
__device__ float atomicMax(float* address, float val) {
|
||||
int* address_as_i = reinterpret_cast<int*>(address);
|
||||
|
@ -47,6 +63,7 @@ __device__ float atomicMin(float* address, float val) {
|
|||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
#endif
|
||||
|
||||
__device__ float dDequantizeFP4(unsigned char val, float absmax)
|
||||
{
|
||||
|
@ -723,6 +740,19 @@ template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TY
|
|||
//__launch_bounds__(TH, 4)
|
||||
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
|
||||
{
|
||||
#ifdef BITS_AND_BYTES_USE_ROCM
|
||||
printf("kQuantizeBlockwise is not supported on Rocm!");
|
||||
|
||||
//TODO: figure out how to make compiler recognize what isn't executed based on template arguments, without the code below in ifndef would trigger static_assert if
|
||||
//this condition is true
|
||||
if ((BLOCK_SIZE / NUM_PER_TH % 64) != 0)
|
||||
{
|
||||
printf("kQuantizeBlockwise not fully supported on Rocm! BLOCK_SIZE/NUM_PER_TH needs to be divisible by 64.");
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef BITS_AND_BYTES_USE_ROCM
|
||||
const int n_full = gridDim.x * BLOCK_SIZE;
|
||||
int valid_items = 0;
|
||||
const int base_idx = (blockIdx.x * BLOCK_SIZE);
|
||||
|
@ -824,6 +854,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
__syncthreads();
|
||||
StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
|
||||
|
@ -3956,6 +3987,7 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
|
|||
template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
|
||||
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
|
||||
|
||||
|
||||
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
|
||||
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
|
||||
|
||||
|
|
|
@ -5,12 +5,17 @@
|
|||
|
||||
#include <ops.cuh>
|
||||
#include <kernels.cuh>
|
||||
#include <cub/device/device_scan.cuh>
|
||||
#include <limits>
|
||||
#include <BinSearch.h>
|
||||
#include <cassert>
|
||||
#include <common.h>
|
||||
|
||||
#ifdef BITS_AND_BYTES_USE_ROCM
|
||||
//#include <hipcub/device/device_scan.hpp>
|
||||
#else
|
||||
#include <cub/device/device_scan.cuh>
|
||||
#endif
|
||||
|
||||
|
||||
using namespace BinSearch;
|
||||
using std::cout;
|
||||
|
|
44
csrc/ops.cuh
44
csrc/ops.cuh
|
@ -12,16 +12,56 @@
|
|||
#include <unistd.h>
|
||||
#include <assert.h>
|
||||
|
||||
|
||||
#ifdef BITS_AND_BYTES_USE_ROCM
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <hipblaslt/hipblaslt.h> //only supports gfx903
|
||||
#include <hipsparse/hipsparse.h>
|
||||
|
||||
#define cudaPeekAtLastError hipPeekAtLastError
|
||||
#define cudaMemset hipMemset
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define CUBLAS_OP_T HIPBLAS_OP_T
|
||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||
#define CUDA_R_8I HIPBLAS_R_8I
|
||||
#define CUDA_R_32I HIPBLAS_R_32I
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
#define cublasLtMatrixLayoutCreate hipblasLtMatrixLayoutCreate
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#define cudaSuccess hipSuccess
|
||||
#define cusparseStatus_t hipsparseStatus_t
|
||||
#define CUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define cublasHandle_t hipblasHandle_t
|
||||
#define cublasCreate_v2 hipblasCreate
|
||||
#define cusparseHandle_t hipsparseHandle_t
|
||||
#define cusparseCreate hipsparseCreate
|
||||
#define __nv_bfloat16 hip_bfloat16
|
||||
#define cublasLtHandle_t hipblasLtHandle_t
|
||||
#define cublasLtCreate hipblasLtCreate
|
||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
|
||||
|
||||
#else
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cublasLt.h>
|
||||
#include <cusparse.h>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#endif
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -5,9 +5,22 @@
|
|||
|
||||
#if BUILD_CUDA
|
||||
#include <ops.cuh>
|
||||
|
||||
#ifdef BITS_AND_BYTES_USE_ROCM
|
||||
#include <hipblaslt/hipblaslt.h>
|
||||
#define cublasLtHandle_t hipblasLtHandle_t
|
||||
#define cudaMemAttachHost hipMemAttachHost
|
||||
#define cudaMallocManaged hipMallocManaged
|
||||
#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
|
||||
#define cudaDeviceGetAttribute hipDeviceGetAttribute
|
||||
#define cudaMemPrefetchAsync hipMemPrefetchAsync
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#include <cpu_ops.h>
|
||||
|
||||
|
||||
|
||||
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
|
||||
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
|
||||
// maintain all that boilerplate
|
||||
|
|
Loading…
Reference in New Issue
Block a user