Best attempt at cutlass3.
parent 84964db937
commit 0afc8e9e2f
Makefile (8 changed lines)
@@ -55,8 +55,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80

CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86

CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
@@ -103,9 +103,9 @@ cuda11x: $(BUILD_DIR) env
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cuda11x_cutlass: $(BUILD_DIR) env cutlass
	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math --expt-relaxed-constexpr -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
	$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(INCLUDE_cutlass) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cuda12x: $(BUILD_DIR) env
	$(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
@@ -1374,6 +1374,104 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
    return sout


def cutlass3_gemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.float32)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.float32, device=A.device)

    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        sA = (sA[0], sA[2], sA[1])
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        sB = (sB[0], sB[2], sB[1])
    # this is a mess: cuBLAS expects column-major, but PyTorch is row-major.
    # So to perform the matrix multiplication, we have to treat the A, B, and C matrices
    # as transposed (the transpose of a row-major matrix is column-major).
    # This means we compute B^T A^T = C^T and explicitly switch the dimensions of each of
    # these matrices in the input arguments for cuBLAS.
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
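    # worked example: for row-major A: [2, 3] and B: [3, 4], the same buffers read as
    # column-major are A^T: [3, 2] and B^T: [4, 3]; computing B^T @ A^T = [4, 3] @ [3, 2]
    # gives C^T: [4, 2] in column-major, whose buffer is exactly the row-major C: [2, 4].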
    if len(sB) == 2:
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            n = sA[0] * sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[(1 if transposed_B else 0)]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(
                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
            )

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]

        lda = m
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    lda = ldb = ldc = 1
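    # note: this assignment discards the lda/ldb/ldc values computed above; together
    # with the print below it looks like leftover debugging from this experiment.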
    #lda = 1
    print(m, n, k, lda, ldb, ldc)
    is_on_gpu([B, A, out])
    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)
    lda = ct.c_int32(lda)
    ldb = ct.c_int32(ldb)
    ldc = ct.c_int32(ldc)
    alpha = ct.c_float(1.0)
    beta = ct.c_float(0.0)
    lib.ccutlass_gemm(m, n, k, alpha, get_ptr(B), lda, get_ptr(A), ldb, beta, get_ptr(out), ldc)
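    # B is passed as the left operand and A as the right one, matching the
    # B^T @ A^T = C^T layout trick described above.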

    return out


def igemm(
    A: Tensor,
csrc/kernels.cu (126 changed lines)
@@ -19,7 +19,6 @@
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/cublas_wrappers.hpp"
#include "cutlass/util/helper_cuda.hpp"

#define HLF_MAX 65504
#define TH 1024
@@ -2928,73 +2927,84 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
}


template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
{
  // element-wise kernel
  // 1. Load batch x k into registers
  // 2. Load k x k into registers
  // 3. dequantize and store in second pair of k x k
  // 4. matmul
  // 5. sum with cub
  // 6. store outputs
  // TC kernel
  // use k warps per thread block
  // 1. threadblock uses read-only cache to read in register tile for A into shared memory
  // 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
  // 3. each warp reads a segment of values 16x32 from B
  // 4. do dequantization from registers of B into second pair of registers
  // 5. store (4) into fragment
  // 6. matmul aggregate into fragment C
  // 7. aggregate tiles of C into shared memory block C
  // 8. sum (7)
  // 9. write outputs to matmul output matrix
}
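As a rough, minimal sketch of the first ("element-wise") plan above -- not part of this diff, and assuming a hypothetical 16-entry dequantization table `code` plus a single row of A for brevity -- such a kernel could look like this:

#include <cuda_fp16.h>
#include <cub/cub.cuh>

// Sketch of the element-wise plan: one thread block per output column,
// each thread accumulates part of the k dimension, then a cub block
// reduction sums the partial products (step 5) and thread 0 stores (step 6).
// B is assumed packed column-major with two 4-bit values per byte.
template <int THREADS>
__global__ void kMatmulInference4bitSketch(const half *A, const unsigned char *B,
                                           const float *code, half *out,
                                           int k, int colsB)
{
  const int col = blockIdx.x;              // output column handled by this block
  if (col >= colsB) return;

  float acc = 0.0f;
  for (int i = threadIdx.x; i < k / 2; i += THREADS)
  {
    unsigned char packed = B[col * (k / 2) + i];
    float b_lo = code[packed & 0x0F];      // dequantize low nibble
    float b_hi = code[packed >> 4];        // dequantize high nibble
    acc += __half2float(A[2 * i + 0]) * b_lo;
    acc += __half2float(A[2 * i + 1]) * b_hi;
  }

  // 5. sum with cub
  typedef cub::BlockReduce<float, THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  float total = BlockReduce(temp_storage).Sum(acc);

  // 6. store outputs
  if (threadIdx.x == 0)
    out[col] = __float2half(total);
}
// example launch: kMatmulInference4bitSketch<256><<<colsB, 256>>>(A, B, code, out, k, colsB);

The tensor-core variant described in the second half of the comment would replace the inner loop with mma fragments; the commit leaves both variants unimplemented.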
//template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
//{
//// element-wise kernel
//// 1. Load batch x k into registers
//// 2. Load k x k into registers
//// 3. dequantize and store in second pair of k x k
//// 4. matmul
//// 5. sum with cub
//// 6. store outputs
//// TC kernel
//// use k warps per thread block
//// 1. threadblock uses read-only cache to read in register tile for A into shared memory
//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
//// 3. each warp reads a segment of values 16x32 from B
//// 4. do dequantization from registers of B into second pair of registers
//// 5. store (4) into fragment
//// 6. matmul aggregate into fragment C
//// 7. aggregate tiles of C into shared memory block C
//// 8. sum (7)
//// 9. write outputs to matmul output matrix
//}

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
#  include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "cutlass/util/helper_cuda.hpp"
//#include "cutlass/util/helper_cuda.hpp"

template <class MShape, class NShape, class KShape,
          class TA, class AStride, class ABlockLayout, class AThreadLayout,
          class TB, class BStride, class BBlockLayout, class BThreadLayout,
          class TC, class CStride, class CBlockLayout, class CThreadLayout,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(MShape M, NShape N, KShape K,
            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
            Alpha alpha, Beta beta)
__global__ void gemm_device(int M, int N, int K,
            float const* A,
            float const* B,
            float * out, int lda, int ldb, int ldc,
            float alpha, float beta)
{
  using namespace cute;
  using X = Underscore;

  // Preconditions
  CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
  //CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
  //CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
  //CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);

  CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
  //CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
  //CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
  //CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);

  CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
  CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
  //CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
  //CUTE_STATIC_ASSERT_V(size(tB) == size(tC));

  // Define block sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};

  // Define the block layouts (static)
  auto bA = make_layout(make_shape(bM,bK));
  auto bB = make_layout(make_shape(bN,bK));
  auto bC = make_layout(make_shape(bM,bN));

  // Define the thread layouts (static)
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));

  //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC));  // BLK_M
  //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC));  // BLK_N
  CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB));    // BLK_K
  //CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB));  // BLK_K

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ABlockLayout>];
  __shared__ TB smemB[cosize_v<BBlockLayout>];
  auto sA = make_tensor(make_smem_ptr(smemA), blockA);           // (BLK_M,BLK_K)
  auto sB = make_tensor(make_smem_ptr(smemB), blockB);           // (BLK_N,BLK_K)
  __shared__ float smemA[128*8];
  __shared__ float smemB[128*8];
  auto sA = make_tensor(make_smem_ptr(smemA), bA);               // (BLK_M,BLK_K)
  auto sB = make_tensor(make_smem_ptr(smemB), bB);               // (BLK_N,BLK_K)

  auto dA = make_stride(Int<1>{}, lda);
  auto dB = make_stride(Int<1>{}, ldb);
  auto dC = make_stride(Int<1>{}, ldc);

  // Represent the full tensors
  auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA);  // (M,K)
@@ -3083,11 +3093,27 @@ gemm_device(MShape M, NShape N, KShape K,
}


//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================

template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
//template <class MShape, class NShape, class KShape,
//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
//          class Alpha, class Beta>
//__global__ static
//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
//void
//gemm_device(MShape M, NShape N, KShape K,
//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
//            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
//            half alpha, half beta);


//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
@@ -9,7 +9,7 @@
#ifndef kernels
#define kernels

template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
//template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);

template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
@@ -122,4 +122,24 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T

template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

//template <class MShape, class NShape, class KShape,
//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
//          class Alpha, class Beta>
//__global__ static
//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
//void
//gemm_device(MShape M, NShape N, KShape K,
//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
//            TC      * out, CStride dC, CBlockLayout , CThreadLayout tC,
//            Alpha alpha, Beta beta);

__global__ void gemm_device(int M, int N, int K,
            float const* A,
            float const* B,
            float * out, int lda, int ldb, int ldc,
            float alpha, float beta);

#endif
csrc/ops.cu (73 changed lines)
@@ -91,14 +91,12 @@ template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsign
}


void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
{
  int num_blocks = (colsB+32-1)/32;
  kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *C, int lda, int ldb, int rowsA, int colsA, int colsB);
//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
//{
//  int num_blocks = (colsB+32-1)/32;
//  kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
//  CUDA_CHECK_RETURN(cudaPeekAtLastError());
//}


template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
@@ -666,60 +664,47 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id




#include <cute/tensor.hpp>
#include "cutlass/util/helper_cuda.hpp"


template <typename TA, typename TB, typename TC,
          typename Alpha, typename Beta>
void
gemm(int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
void gemm_host(int m, int n, int k,
     float alpha,
     float const* A, int lda,
     float const* B, int ldb,
     float beta,
     float * C, int ldc)
{
  cute::device_init(0);
  using namespace cute;



  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);

  // Define strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);
  auto dB = make_stride(Int<1>{}, ldB);
  auto dC = make_stride(Int<1>{}, ldC);

  // Define block sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  printf("%i %i %i %i %i %i\n", m, n, k, lda, ldb, ldc);

  // Define the block layouts (static)
  auto sA = make_layout(make_shape(bM,bK));
  auto sB = make_layout(make_shape(bN,bK));
  auto sC = make_layout(make_shape(bM,bN));

  // Define the thread layouts (static)
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));

  dim3 dimBlock(size(tC));
  dim3 dimGrid(ceil_div(size(M), size(bM)),
               ceil_div(size(N), size(bN)));
  dim3 dimBlock(16, 16);
  dim3 dimGrid((M+127)/128, (N+127)/128);
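  // one 128x128 tile of C per thread block; dimBlock(16, 16) = 256 threads,
  // matching the 16x16 tC thread layout defined above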
//  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
//-
//-  dim3 dimBlock(size(tC));
//-  dim3 dimGrid(ceil_div(size(M), size(bM)),
//-               ceil_div(size(N), size(bN)));
  gemm_device
      <<< dimGrid, dimBlock, 0, stream >>>
      <<< dimGrid, dimBlock, 0, 0 >>>
      (M, N, K,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       A,
       B,
       C, lda, ldb, ldc,
       alpha, beta);
}


//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
csrc/ops.cuh (12 changed lines)
@@ -20,6 +20,11 @@
#include <vector>
#include <functional>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>



#define CUDA_CHECK_RETURN(value) {                  \
  cudaError_t _m_cudaStat = value;                  \
  if (_m_cudaStat != cudaSuccess) {                 \
@@ -185,4 +190,11 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id

void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);

void gemm_host(int m, int n, int k,
     float alpha,
     float const* A, int ldA,
     float const* B, int ldB,
     float beta,
     float * C, int ldC);

#endif
@@ -20,6 +20,16 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }


void
cppgemm(int m, int n, int k,
     float alpha,
     float const* A, int ldA,
     float const* B, int ldB,
     float beta,
     float * C, int ldC)
{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}


#define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_g##gbits(gtype *g, gtype *p, \
               float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
@@ -306,6 +316,14 @@ extern "C"
	void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
	void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }

	void ccutlass_gemm(int m, int n, int k,
	     float alpha,
	     float const* A, int ldA,
	     float const* B, int ldB,
	     float beta,
	     float * C, int ldC)
	{ cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}

#endif
	void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
	void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
@@ -2351,3 +2351,24 @@ def test_normal_map_tree():
        pivots.append((values[i-1]+values[i])/2)
    print(pivots)


def test_cutlass3_gemm():
    #A = torch.rand(2, 2).cuda()
    #B = torch.rand(2, 2).cuda()
    A = torch.arange(4).reshape(2, 2).float().cuda().contiguous()
    B = torch.ones(2, 2).float().cuda()

    print('')
    print(A)
    print(B)

    C1 = torch.matmul(A, B)
    print(C1)
    C2 = F.cutlass3_gemm(A, B.t())
    print(C2)
    C2 = F.cutlass3_gemm(A, B)
    print(C2)
    C2 = F.cutlass3_gemm(B.t(), A.t().contiguous())
    print(C2)
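    # no assertions yet; the outputs are only printed for manual comparison against torch.matmul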