Best attempt at cutlass3.

Tim Dettmers 2023-04-26 17:12:34 -07:00
parent 84964db937
commit 0afc8e9e2f
8 changed files with 279 additions and 99 deletions

Makefile

@@ -55,8 +55,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
-CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
-CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
@@ -103,9 +103,9 @@ cuda11x: $(BUILD_DIR) env
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
 cuda11x_cutlass: $(BUILD_DIR) env cutlass
-	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
+	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math --expt-relaxed-constexpr -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
 	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
-	$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
+	$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(INCLUDE_cutlass) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
 cuda12x: $(BUILD_DIR) env
 	$(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)

bitsandbytes/functional.py

@@ -1374,6 +1374,104 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
     return sout
+
+def cutlass3_gemm(
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
+):
+    sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.float32)
+    if out is None:
+        out = torch.zeros(size=sout, dtype=torch.float32, device=A.device)
+
+    sA = A.shape
+    sB = B.shape
+    if transposed_A and len(sA) == 2:
+        sA = (sA[1], sA[0])
+    elif transposed_A and len(sA) == 3:
+        sA = (sA[0], sA[2], sA[0])
+    if transposed_B and len(sB) == 2:
+        sB = (sB[1], sB[0])
+    elif transposed_B and len(sB) == 3:
+        sB = (sB[0], sB[2], sB[0])
+    # This is a mess: cuBLAS expects column-major data, but PyTorch tensors are row-major.
+    # To perform the matrix multiplication we therefore treat A, B, and C as their transposes
+    # (the transpose of a row-major matrix is the same memory interpreted as column-major).
+    # This means we compute B^T A^T = C^T and explicitly swap the dimensions of each of these
+    # matrices in the arguments passed to cuBLAS; see the standalone sketch below this diff.
+    # column major:                        A @ B   = C:   [m, k] @ [k, n] = [m, n]
+    # row major:                         B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
+    # column major with row-major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
+    if len(sB) == 2:
+        if B.stride()[0] == B.shape[1]:
+            transposed_B = False
+        elif B.stride()[1] == B.shape[0]:
+            transposed_B = True
+        if len(A.shape) == 2:
+            if A.stride()[0] == A.shape[1]:
+                transposed_A = False
+            elif A.stride()[1] == A.shape[0]:
+                transposed_A = True
+        else:
+            if A.stride()[1] == A.shape[2]:
+                transposed_A = False
+            elif A.stride()[2] == A.shape[1]:
+                transposed_A = True
+
+        if len(sA) == 2:
+            n = sA[0]
+            ldb = A.stride()[1 if transposed_A else 0]
+        elif len(sA) == 3 and len(sB) == 2:
+            n = sA[0] * sA[1]
+            ldb = sA[2]
+
+        m = sB[1]
+        k = sB[0]
+        lda = B.stride()[(1 if transposed_B else 0)]
+        ldc = sB[1]
+    elif len(sB) == 3:
+        # special case
+        assert len(sA) == 3
+        if not (sA[0] == sB[0] and sA[1] == sB[1]):
+            raise ValueError(
+                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
+            )
+
+        transposed_A = True
+        transposed_B = False
+
+        m = sB[2]
+        n = sA[2]
+        k = sB[0] * sB[1]
+
+        lda = m
+        ldb = sA[2]
+        ldc = m
+
+    ptr = CUBLAS_Context.get_instance().get_context(A.device)
+
+    # B^T @ A^T = C^T
+    # [km, nk -> mn]
+    lda = ldb = ldc = 1
+    #lda = 1
+    print(m, n, k, lda, ldb, ldc)
+    is_on_gpu([B, A, out])
+    m = ct.c_int32(m)
+    n = ct.c_int32(n)
+    k = ct.c_int32(k)
+    lda = ct.c_int32(lda)
+    ldb = ct.c_int32(ldb)
+    ldc = ct.c_int32(ldc)
+    alpha = ct.c_float(1.0)
+    beta = ct.c_float(0.0)
+    lib.ccutlass_gemm(m, n, k, alpha, get_ptr(B), lda, get_ptr(A), ldb, beta, get_ptr(out), ldc)
+
+    return out
 def igemm(
     A: Tensor,
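
The comment block in cutlass3_gemm above relies on the identity (A B)^T = B^T A^T: handing row-major buffers to a column-major GEMM makes it see the transposed matrices, so asking it for B^T A^T yields C^T, which read back as row-major is exactly C. A minimal PyTorch sketch of that operand swap (illustrative only; it does not call the new kernel):

    import torch

    # Row-major A (m x k) and B (k x n); the goal is C = A @ B with shape (m, n).
    m, k, n = 4, 3, 5
    A = torch.randn(m, k)
    B = torch.randn(k, n)

    C = A @ B

    # A column-major GEMM given these row-major buffers interprets them as
    # A^T (k x m) and B^T (n x k). Requesting B^T @ A^T therefore produces
    # C^T (n x m) in column-major order, i.e. C in row-major order. This is
    # why cutlass3_gemm passes get_ptr(B) as the first operand and get_ptr(A)
    # as the second.
    C_via_swap = (B.t() @ A.t()).t()

    assert torch.allclose(C, C_via_swap)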

csrc/kernels.cu

@@ -19,7 +19,6 @@
 #include "cutlass/util/print_error.hpp"
 #include "cutlass/util/GPU_Clock.hpp"
 #include "cutlass/util/cublas_wrappers.hpp"
-#include "cutlass/util/helper_cuda.hpp"
 #define HLF_MAX 65504
 #define TH 1024
@@ -2928,73 +2927,84 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 }
-template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
-{
-// element-wise kernel
-// 1. Load batch x k into registers
-// 2. Load k x k into registers
-// 3. dequantize and store in second pair of k x k
-// 4. matmul
-// 5. sum with cub
-// 6. store outputs
-// TC kernel
-// use k warps per thread block
-// 1. threadblock use read-only cache to read in register tile for A into shared memory
-// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
-// 3. each warp reads a segment of values 16x32 from B
-// 4. do dequantization from register of B into second pair of registers
-// 5. store (4) into fragment
-// 6. matmul aggregate into fragment C
-// 7. aggreecate files of C into shared memroy block C
-// 8. sum (7)
-// 9. write outputs to matmul output matrix
-}
+//template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
+//{
+//// element-wise kernel
+//// 1. Load batch x k into registers
+//// 2. Load k x k into registers
+//// 3. dequantize and store in second pair of k x k
+//// 4. matmul
+//// 5. sum with cub
+//// 6. store outputs
+//// TC kernel
+//// use k warps per thread block
+//// 1. threadblock use read-only cache to read in register tile for A into shared memory
+//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
+//// 3. each warp reads a segment of values 16x32 from B
+//// 4. do dequantization from register of B into second pair of registers
+//// 5. store (4) into fragment
+//// 6. matmul aggregate into fragment C
+//// 7. aggreecate files of C into shared memroy block C
+//// 8. sum (7)
+//// 9. write outputs to matmul output matrix
+//}
 #include "cutlass/util/print_error.hpp"
 #include "cutlass/util/GPU_Clock.hpp"
 #if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
 #  include "cutlass/util/cublas_wrappers.hpp"
 #endif
-#include "cutlass/util/helper_cuda.hpp"
+//#include "cutlass/util/helper_cuda.hpp"
-template <class MShape, class NShape, class KShape,
-          class TA, class AStride, class ABlockLayout, class AThreadLayout,
-          class TB, class BStride, class BBlockLayout, class BThreadLayout,
-          class TC, class CStride, class CBlockLayout, class CThreadLayout,
-          class Alpha, class Beta>
-__global__ static
-__launch_bounds__(decltype(size(CThreadLayout{}))::value)
-void
-gemm_device(MShape M, NShape N, KShape K,
-            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
-            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
-            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
-            Alpha alpha, Beta beta)
+__global__ void gemm_device(int M, int N, int K,
+            float const* A,
+            float const* B,
+            float * out, int lda, int ldb, int ldc,
+            float alpha, float beta)
 {
   using namespace cute;
   using X = Underscore;
   // Preconditions
-  CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
-  CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
-  CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
-  CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
+  //CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
+  //CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
+  //CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
+  //CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
+  // Define block sizes (static)
+  auto bM = Int<128>{};
+  auto bN = Int<128>{};
+  auto bK = Int< 8>{};
+  // Define the block layouts (static)
+  auto bA = make_layout(make_shape(bM,bK));
+  auto bB = make_layout(make_shape(bN,bK));
+  auto bC = make_layout(make_shape(bM,bN));
+  // Define the thread layouts (static)
+  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
+  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
+  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
   //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC));   // BLK_M
   //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC));   // BLK_N
-  CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB));     // BLK_K
+  //CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB));   // BLK_K
   // Shared memory buffers
-  __shared__ TA smemA[cosize_v<ABlockLayout>];
-  __shared__ TB smemB[cosize_v<BBlockLayout>];
-  auto sA = make_tensor(make_smem_ptr(smemA), blockA);            // (BLK_M,BLK_K)
-  auto sB = make_tensor(make_smem_ptr(smemB), blockB);            // (BLK_N,BLK_K)
+  __shared__ float smemA[128*8];
+  __shared__ float smemB[128*8];
+  auto sA = make_tensor(make_smem_ptr(smemA), bA);                // (BLK_M,BLK_K)
+  auto sB = make_tensor(make_smem_ptr(smemB), bB);                // (BLK_N,BLK_K)
+  auto dA = make_stride(Int<1>{}, lda);
+  auto dB = make_stride(Int<1>{}, ldb);
+  auto dC = make_stride(Int<1>{}, ldc);
   // Represent the full tensors
   auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA);   // (M,K)
@@ -3083,11 +3093,27 @@ gemm_device(MShape M, NShape N, KShape K,
 }
 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
-template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
+//template <class MShape, class NShape, class KShape,
+//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
+//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
+//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
+//          class Alpha, class Beta>
+//__global__ static
+//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
+//void
+//gemm_device(MShape M, NShape N, KShape K,
+//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
+//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
+//            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
+//            half alpha, half beta);
+//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

csrc/kernels.cuh

@@ -9,7 +9,7 @@
 #ifndef kernels
 #define kernels
-template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
+//template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
 template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
@@ -122,4 +122,24 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
 template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+
+//template <class MShape, class NShape, class KShape,
+//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
+//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
+//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
+//          class Alpha, class Beta>
+//__global__ static
+//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
+//void
+//gemm_device(MShape M, NShape N, KShape K,
+//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
+//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
+//            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
+//            Alpha alpha, Beta beta);
+
+__global__ void gemm_device(int M, int N, int K,
+            float const* A,
+            float const* B,
+            float * out, int lda, int ldb, int ldc,
+            float alpha, float beta);
 #endif

csrc/ops.cu

@@ -91,14 +91,12 @@ template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsign
 }
-void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
-{
-  int num_blocks = (colsB+32-1)/32;
-  kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
-  CUDA_CHECK_RETURN(cudaPeekAtLastError());
-}
-template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *C, int lda, int ldb, int rowsA, int colsA, int colsB);
+//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
+//{
+//  int num_blocks = (colsB+32-1)/32;
+//  kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
+//  CUDA_CHECK_RETURN(cudaPeekAtLastError());
+//}
 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
@@ -666,60 +664,47 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 #include <cute/tensor.hpp>
-#include "cutlass/util/helper_cuda.hpp"
-template <typename TA, typename TB, typename TC,
-          typename Alpha, typename Beta>
-void
-gemm(int m, int n, int k,
-     Alpha alpha,
-     TA const* A, int ldA,
-     TB const* B, int ldB,
-     Beta beta,
-     TC   * C, int ldC,
-     cudaStream_t stream = 0)
+void gemm_host(int m, int n, int k,
+     float alpha,
+     float const* A, int lda,
+     float const* B, int ldb,
+     float beta,
+     float * C, int ldc)
 {
-  cute::device_init(0);
   using namespace cute;
   // Define shapes (dynamic)
   auto M = int(m);
   auto N = int(n);
   auto K = int(k);
-  // Define strides (mixed)
-  auto dA = make_stride(Int<1>{}, ldA);
-  auto dB = make_stride(Int<1>{}, ldB);
-  auto dC = make_stride(Int<1>{}, ldC);
-  // Define block sizes (static)
-  auto bM = Int<128>{};
-  auto bN = Int<128>{};
-  auto bK = Int< 8>{};
-  // Define the block layouts (static)
-  auto sA = make_layout(make_shape(bM,bK));
-  auto sB = make_layout(make_shape(bN,bK));
-  auto sC = make_layout(make_shape(bM,bN));
-  // Define the thread layouts (static)
-  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
-  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
-  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
-  dim3 dimBlock(size(tC));
-  dim3 dimGrid(ceil_div(size(M), size(bM)),
-               ceil_div(size(N), size(bN)));
+  printf("%i %i %i %i %i %i\n", m, n, k, lda, ldb, ldc);
+  dim3 dimBlock(16, 16);
+  dim3 dimGrid((M+127)/128, (N+127)/128);
+//  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
+//-
+//-  dim3 dimBlock(size(tC));
+//-  dim3 dimGrid(ceil_div(size(M), size(bM)),
+//-               ceil_div(size(N), size(bN)));
   gemm_device
-      <<< dimGrid, dimBlock, 0, stream >>>
+      <<< dimGrid, dimBlock, 0, 0 >>>
       (M, N, K,
-       A, dA, sA, tA,
-       B, dB, sB, tB,
-       C, dC, sC, tC,
+       A,
+       B,
+       C, lda, ldb, ldc,
        alpha, beta);
 }
 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
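
For reference, the replacement launch configuration in gemm_host above uses a 16x16 thread block and one block per 128x128 tile of the output. A small sketch of that grid arithmetic (plain Python, names are illustrative only and not part of the library):

    # Mirrors `dim3 dimBlock(16, 16)` and `dim3 dimGrid((M+127)/128, (N+127)/128)`
    # from gemm_host; the tile sizes are read off the hard-coded launch lines.
    def launch_dims(M, N, tile_m=128, tile_n=128):
        grid = ((M + tile_m - 1) // tile_m, (N + tile_n - 1) // tile_n)
        block = (16, 16)
        return grid, block

    print(launch_dims(2, 2))        # ((1, 1), (16, 16)) for the 2x2 test matrices
    print(launch_dims(4096, 4096))  # ((32, 32), (16, 16))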

csrc/ops.cuh

@@ -20,6 +20,11 @@
 #include <vector>
 #include <functional>
+
+#include <thrust/host_vector.h>
+#include <thrust/device_vector.h>
+
 #define CUDA_CHECK_RETURN(value) { \
   cudaError_t _m_cudaStat = value; \
   if (_m_cudaStat != cudaSuccess) { \
@@ -185,4 +190,11 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
+
+void gemm_host(int m, int n, int k,
+     float alpha,
+     float const* A, int ldA,
+     float const* B, int ldB,
+     float beta,
+     float * C, int ldC);
 #endif

csrc/pythonInterface.c

@@ -20,6 +20,16 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
 void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
+
+void
+cppgemm(int m, int n, int k,
+     float alpha,
+     float const* A, int ldA,
+     float const* B, int ldB,
+     float beta,
+     float * C, int ldC)
+{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC); }
+
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
 void fname##32bit_g##gbits(gtype *g, gtype *p, \
                float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
@@ -306,6 +316,14 @@ extern "C"
 	void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
 	void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
+
+	void ccutlass_gemm(int m, int n, int k,
+	     float alpha,
+	     float const* A, int ldA,
+	     float const* B, int ldB,
+	     float beta,
+	     float * C, int ldC)
+	{ cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC); }
 #endif
 void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
 void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }

tests/test_functional.py

@@ -2351,3 +2351,24 @@ def test_normal_map_tree():
         pivots.append((values[i-1]+values[i])/2)
     print(pivots)
+
+
+def test_cutlass3_gemm():
+    #A = torch.rand(2, 2).cuda()
+    #B = torch.rand(2, 2).cuda()
+    A = torch.arange(4).reshape(2, 2).float().cuda().contiguous()
+    B = torch.ones(2, 2).float().cuda()
+
+    print('')
+    print(A)
+    print(B)
+
+    C1 = torch.matmul(A, B)
+    print(C1)
+
+    C2 = F.cutlass3_gemm(A, B.t())
+    print(C2)
+
+    C2 = F.cutlass3_gemm(A, B)
+    print(C2)
+
+    C2 = F.cutlass3_gemm(B.t(), A.t().contiguous())
+    print(C2)
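
The new test only prints the reference and kernel outputs side by side, which matches the work-in-progress state of the kernel. A minimal sketch of the assertion it could grow into once the kernel is correct (assuming cutlass3_gemm keeps returning a float32 tensor shaped like torch.matmul's result):

    import torch

    def assert_gemm_close(C_ref, C_test, atol=1e-5, rtol=1e-5):
        # Hypothetical helper, not part of the commit: compare the reference
        # torch.matmul result against the kernel output.
        assert C_ref.shape == C_test.shape
        assert torch.allclose(C_ref, C_test, atol=atol, rtol=rtol)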