// 2022-10-27 11:11:29 +00:00
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// 2021-10-06 02:16:20 +00:00
// LICENSE file in the root directory of this source tree.
#ifndef ops_H
#define ops_H
#include <assert.h>
#include <stdio.h>
#include <unistd.h>

#include <iostream>
#include <stdexcept>
// 2023-08-05 00:11:46 +00:00
#ifdef BITS_AND_BYTES_USE_ROCM
// 2023-08-08 18:50:26 +00:00
#ifndef DISABLE_WARP_32
#define __AMDGCN_WAVEFRONT_SIZE 32 // check rocminfo | grep "Wavefront Size". Should be supported on all new GPU's
#endif
// 2023-08-05 00:11:46 +00:00
#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>
#include <hipsparse/hipsparse.h>
// 2023-10-13 00:16:14 +00:00
#include <hip/hip_bfloat16.h>
// 2023-08-05 00:11:46 +00:00
// ---------------------------------------------------------------------------
// CUDA -> HIP/ROCm compatibility aliases.
// Deduplicated: cublasStatus_t was previously defined three times and
// CUBLAS_STATUS_SUCCESS twice; each alias now appears exactly once.
// ---------------------------------------------------------------------------

// CUDA runtime / error handling
#define cudaPeekAtLastError hipPeekAtLastError
#define cudaMemset hipMemset
#define cudaError_t hipError_t
#define cudaGetErrorString hipGetErrorString
#define cudaSuccess hipSuccess

// cuBLAS -> hipBLAS
#define cublasHandle_t hipblasHandle_t
#define cublasCreate_v2 hipblasCreate
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
#define cublasGemmEx hipblasGemmEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
#define CUDA_R_8I HIPBLAS_R_8I
#define CUDA_R_32I HIPBLAS_R_32I

// cuBLASLt has no direct equivalent mapped here; alias onto plain hipBLAS.
#define cublasLtHandle_t hipblasHandle_t
#define cublasLtCreate hipblasCreate
// NOTE(review): verify `hipblasMatrixLayoutCreate` exists; hipblasLt names
// this hipblasLtMatrixLayoutCreate — confirm against the hipBLAS headers.
#define cublasLtMatrixLayoutCreate hipblasMatrixLayoutCreate

// cuSPARSE -> hipSPARSE
#define cusparseHandle_t hipsparseHandle_t
#define cusparseCreate hipsparseCreate
#define cusparseStatus_t hipsparseStatus_t
#define CUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS

// bfloat16 storage type
#define __nv_bfloat16 hip_bfloat16
#else
// 2021-10-06 02:16:20 +00:00
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
// 2022-07-22 21:41:05 +00:00
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
// 2021-10-06 02:16:20 +00:00
// 2023-04-27 00:12:34 +00:00
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
// 2023-08-05 00:11:46 +00:00
#endif
#include <vector>
#include <functional>
// 2023-04-27 00:12:34 +00:00
// 2021-10-06 02:16:20 +00:00
// Checks the result of a CUDA runtime call; on failure prints the error and
// source location to stderr and terminates the process.
// Wrapped in do { ... } while (0) so the macro expands to a single statement
// and is safe inside an unbraced if/else; the argument is parenthesized to
// keep operator precedence intact.
#define CUDA_CHECK_RETURN(value) do {                           \
	cudaError_t _m_cudaStat = (value);                            \
	if (_m_cudaStat != cudaSuccess) {                             \
		fprintf(stderr, "Error %s at line %d in file %s\n",         \
				cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__);   \
		exit(1);                                                    \
	} } while (0)
#define THREADS_PER_BLOCKS (512)
// 2022-07-22 21:41:05 +00:00
// Checks the result of a cuSPARSE call; on failure prints the error and
// source location to stderr and terminates the process.
// Wrapped in do { ... } while (0) so the macro expands to a single statement
// and is safe inside an unbraced if/else; the argument is parenthesized.
// NOTE(review): on the ROCm path cusparseGetErrorString is not aliased to a
// hipSPARSE equivalent above — confirm this macro is only used in CUDA builds.
#define CHECK_CUSPARSE(value) do {                              \
	cusparseStatus_t _m_cudaStat = (value);                       \
	if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) {                 \
		fprintf(stderr, "Error %s at line %d in file %s\n",         \
				cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
		exit(1);                                                    \
	} } while (0)
// (duplicate definition of THREADS_PER_BLOCKS removed; it is already defined
// once earlier in this file)
// Validates a CUDA runtime status code. On failure logs the numeric status
// and its message, then throws std::logic_error (requires <stdexcept>).
inline void checkCudaStatus(cudaError_t status) {
    if (status != cudaSuccess) {
        // Log to stderr for consistency with CUDA_CHECK_RETURN/CHECK_CUSPARSE
        // (previously this went to stdout via printf).
        fprintf(stderr, "cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
        throw std::logic_error("cuda API failed");
    }
}
// Validates a cuBLAS status code. Returns 0 on success and 1 on failure
// after logging; deliberately does NOT throw (the throw below was disabled
// on purpose so callers can continue after a failed cuBLAS call).
inline int checkCublasStatus(cublasStatus_t status) {
    if (status == CUBLAS_STATUS_SUCCESS)
        return 0;
    printf("cuBLAS API failed with status %d\n", status);
    //throw std::logic_error("cuBLAS API failed");
    return 1;
}
// 2021-10-06 02:16:20 +00:00
// Element-wise operation selector used by the generic op kernels.
typedef enum Operations_t
{
	ksmul = 0, // scalar multiplication
} Operations_t;
// Optimizer selector passed as the OPTIMIZER template argument to the
// 32-bit and 8-bit optimizer entry points below. The numeric values are
// part of the caller-facing ABI — do not renumber.
// (Stray VCS timestamp lines that broke compilation were removed.)
typedef enum Optimizer_t
{
	ADAM = 0,
	MOMENTUM = 1,
	RMSPROP = 2,
	LARS = 3,
	ADAGRAD = 4,
	LION = 5,
} Optimizer_t;
// 2022-07-22 21:41:05 +00:00
// Matrix memory-layout selector used by the transform/igemmlt paths.
// ROW/COL are plain row-/column-major; the remaining values are presumably
// the cuBLASLt tiled layouts (COL32 and the Turing/Ampere-specific tile
// orders) — confirm against the transform kernels before relying on this.
typedef enum Transform_t
{
	ROW = 0,
	COL = 1,
	COL32 = 2,
	COL_TURING = 3,
	COL_AMPERE = 4,
} Transform_t;
// 2023-04-02 23:10:35 +00:00
// Quantization data-type selector; used as the DATA_TYPE template argument
// of quantizeBlockwise/dequantizeBlockwise below.
typedef enum DataType_t
{
	General8bit = 0, // generic 8-bit codebook quantization
	FP4 = 1,         // 4-bit float format
	NF4 = 2,         // 4-bit NormalFloat format
} DataType_t;
// 2023-05-06 18:14:06 +00:00
// Element-wise function selector; used as the FUNC template argument of the
// generic func() entry point below. Semantics suggested by the names
// (fill / arange / multiply) — confirm against the kernel implementation.
typedef enum Funcs_t
{
	FILL = 0,
	ARANGE = 1,
	_MUL = 2,
} Funcs_t;
// 2022-07-22 21:41:05 +00:00
// Owns a cuBLAS handle created at construction.
// NOTE(review): the status returned by cublasCreate_v2 is not checked and
// the handle is never destroyed; the object appears intended to live for
// the lifetime of the process — confirm before adding a destructor.
class Context
{
    public:
        cublasHandle_t m_handle;

        Context()
        {
            // Create directly into the member; no temporary needed.
            cublasCreate_v2(&m_handle);
        }
};
// Owns a cuBLASLt handle created at construction.
// NOTE(review): the status returned by cublasLtCreate is not checked and
// the handle is never destroyed — same lifetime assumption as Context.
class ContextLt
{
    public:
        cublasLtHandle_t m_handle;

        ContextLt()
        {
            // Create directly into the member; no temporary needed.
            cublasLtCreate(&m_handle);
        }
};
// Owns a cuSPARSE handle created at construction.
// NOTE(review): the status returned by cusparseCreate is not checked and
// the handle is never destroyed — same lifetime assumption as Context.
class ContextCusparse
{
    public:
        cusparseHandle_t m_handle;

        ContextCusparse()
        {
            // Create directly into the member; no temporary needed.
            cusparseCreate(&m_handle);
        }
};
// 2021-10-06 02:16:20 +00:00
// --- Quantile estimation and (de)quantization entry points ---
// (Stray VCS timestamp artifact line removed from this span.)

// Estimates quantiles of A into `code`; `offset` shifts the quantile
// positions (presumably a 256-entry codebook — confirm against the kernel).
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);

// 8-bit quantization/dequantization of n elements using codebook `code`.
void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n);

// Block-wise absmax quantization; DATA_TYPE selects General8bit/FP4/NF4 and
// STOCHASTIC enables stochastic rounding fed by `rand`/`rand_offset`.
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
// --- Optimizer entry points ---
// (Stray VCS timestamp artifact lines removed from this span.)

// Full-precision (32-bit state) optimizer update; OPTIMIZER selects an
// Optimizer_t variant. `unorm`/`max_unorm`/`param_norm` feed update clipping;
// `gnorm_scale` rescales gradients; `skip_zeros` skips zero-valued entries.
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                float beta1, float beta2, float eps, float weight_decay,
                int step, float lr, const float gnorm_scale, bool skip_zeros, int n);

// 8-bit optimizer with dynamic quantization of the optimizer state; the
// quantiles/max/new_max buffers carry the per-tensor quantization statistics.
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
                float *unorm, float max_unorm, float param_norm,
                float beta1, float beta2,
                float eps, int step, float lr,
                float* quantiles1, float* quantiles2,
                float* max1, float* max2, float* new_max1, float* new_max2,
                float weight_decay,
                const float gnorm_scale, int n);

// 8-bit optimizer with block-wise quantized state (absmax per block).
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
                bool skip_zeros, int n);

// Tracks gradient norms across steps in `gnorm_vec` for percentile clipping.
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);

// Scatter-adds `src` into a 2D histogram addressed by (index1, index2).
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
// 2022-07-22 21:41:05 +00:00
// --- GEMM / int8 matmul pipeline entry points ---
// (Stray VCS timestamp artifact lines removed from this span.)

// cublasGemmEx wrapper using the handle owned by `context`.
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);

// Strided-batched GEMM wrapper (cublasGemmStridedBatchedEx).
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
                long long int strideA, long long int strideB, long long int strideC, int batchCount);

// int8 Lt matmul; FORMATB selects the B layout (see Transform_t), DTYPE_OUT
// the output type, SCALE_ROWS whether `row_scale` is applied. Returns nonzero
// on failure (presumably mirroring checkCublasStatus — confirm in the .cu).
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);

// Layout transform from SRC to TARGET format (Transform_t values).
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);

// int8 GEMM via CUTLASS-style path.
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);

// Dequantizes an int32 matmul result to fp16 using row/col statistics,
// optionally adding `bias`.
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols);

// Computes per-row/per-column absmax statistics and counts values above
// nnz_threshold per row (outlier detection).
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);

// Quantizes A row-wise and column-wise in one pass; values above `threshold`
// are routed to the sparse COO buffers (rowidx/colidx/val).
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
                int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);

// Converts a row-major int8 matrix to the FORMAT layout (Transform_t).
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);

// Sparse(COO) x dense matmul via cuSPARSE.
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);

// Naive SpMM specialized for very sparse COO inputs; BITS selects the B dtype width.
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
// --- Outlier extraction and 4-bit inference entry points ---
// (Stray VCS timestamp artifact lines removed from this span.)

// Gathers the columns listed in `idx` from A (stored in FORMAT layout) into `out`.
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);

// fp16 x 4-bit matmul (elementwise path).
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);

// Host-side GEMM dispatch; `bits` selects the weight bit-width.
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);

// 4-bit inference GEMM with block-wise absmax dequantization of B.
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);

// Naive 4-bit inference GEMM; `datatype` holds the FP4/NF4 codebook.
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

// Generic element-wise kernel dispatch; FUNC selects a Funcs_t operation.
template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
// 2023-04-27 22:12:49 +00:00
// 2021-10-06 02:16:20 +00:00
#endif