Best attempt at cutlass3.

This commit is contained in:
Tim Dettmers 2023-04-26 17:12:34 -07:00
parent 84964db937
commit 0afc8e9e2f
8 changed files with 279 additions and 99 deletions

View File

@ -55,8 +55,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
@ -103,9 +103,9 @@ cuda11x: $(BUILD_DIR) env
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda11x_cutlass: $(BUILD_DIR) env cutlass
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math --expt-relaxed-constexpr -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(INCLUDE_cutlass) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda12x: $(BUILD_DIR) env
$(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)

View File

@ -1374,6 +1374,104 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
return sout
def cutlass3_gemm(
A: Tensor,
B: Tensor,
out: Tensor = None,
transposed_A=False,
transposed_B=False,
):
sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.float32)
if out is None:
out = torch.zeros(size=sout, dtype=torch.float32, device=A.device)
sA = A.shape
sB = B.shape
if transposed_A and len(sA) == 2:
sA = (sA[1], sA[0])
elif transposed_A and len(sA) == 3:
sA = (sA[0], sA[2], sA[0])
if transposed_B and len(sB) == 2:
sB = (sB[1], sB[0])
elif transposed_B and len(sB) == 3:
sB = (sB[0], sB[2], sB[0])
# this is a mess: cuBLAS expect column major, but PyTorch is row major.
# So to perform the matrix multiplication, we have to treat A, B, and C matrices
# (transpose of row major is column major)
# This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
# matrices in the input arguments for cuBLAS
# column major: A @ B = C: [m, k] @ [k, n] = [m, n]
# row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
# column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
if len(sB) == 2:
if B.stride()[0] == B.shape[1]:
transposed_B = False
elif B.stride()[1] == B.shape[0]:
transposed_B = True
if len(A.shape) == 2:
if A.stride()[0] == A.shape[1]:
transposed_A = False
elif A.stride()[1] == A.shape[0]:
transposed_A = True
else:
if A.stride()[1] == A.shape[2]:
transposed_A = False
elif A.stride()[2] == A.shape[1]:
transposed_A = True
if len(sA) == 2:
n = sA[0]
ldb = A.stride()[1 if transposed_A else 0]
elif len(sA) == 3 and len(sB) == 2:
n = sA[0] * sA[1]
ldb = sA[2]
m = sB[1]
k = sB[0]
lda = B.stride()[(1 if transposed_B else 0)]
ldc = sB[1]
elif len(sB) == 3:
# special case
assert len(sA) == 3
if not (sA[0] == sB[0] and sA[1] == sB[1]):
raise ValueError(
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
)
transposed_A = True
transposed_B = False
m = sB[2]
n = sA[2]
k = sB[0] * sB[1]
lda = m
ldb = sA[2]
ldc = m
ptr = CUBLAS_Context.get_instance().get_context(A.device)
# B^T @ A^T = C^T
# [km, nk -> mn]
lda = ldb = ldc = 1
#lda = 1
print(m, n, k, lda, ldb, ldc)
is_on_gpu([B, A, out])
m = ct.c_int32(m)
n = ct.c_int32(n)
k = ct.c_int32(k)
lda = ct.c_int32(lda)
ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc)
alpha = ct.c_float(1.0)
beta = ct.c_float(0.0)
lib.ccutlass_gemm(m, n, k, alpha, get_ptr(B), lda, get_ptr(A), ldb, beta, get_ptr(out), ldc)
return out
def igemm(
A: Tensor,

View File

@ -19,7 +19,6 @@
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/cublas_wrappers.hpp"
#include "cutlass/util/helper_cuda.hpp"
#define HLF_MAX 65504
#define TH 1024
@ -2928,73 +2927,84 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
}
template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
{
// element-wise kernel
// 1. Load batch x k into registers
// 2. Load k x k into registers
// 3. dequantize and store in second pair of k x k
// 4. matmul
// 5. sum with cub
// 6. store outputs
// TC kernel
// use k warps per thread block
// 1. threadblock use read-only cache to read in register tile for A into shared memory
// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
// 3. each warp reads a segment of values 16x32 from B
// 4. do dequantization from register of B into second pair of registers
// 5. store (4) into fragment
// 6. matmul aggregate into fragment C
// 7. aggreecate files of C into shared memroy block C
// 8. sum (7)
// 9. write outputs to matmul output matrix
}
//template <int QUANT_TYPE, typename INPT, typename COMPT, typename OUTT> __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB)
//{
//// element-wise kernel
//// 1. Load batch x k into registers
//// 2. Load k x k into registers
//// 3. dequantize and store in second pair of k x k
//// 4. matmul
//// 5. sum with cub
//// 6. store outputs
//// TC kernel
//// use k warps per thread block
//// 1. threadblock use read-only cache to read in register tile for A into shared memory
//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
//// 3. each warp reads a segment of values 16x32 from B
//// 4. do dequantization from register of B into second pair of registers
//// 5. store (4) into fragment
//// 6. matmul aggregate into fragment C
//// 7. aggreecate files of C into shared memroy block C
//// 8. sum (7)
//// 9. write outputs to matmul output matrix
//}
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
# include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "cutlass/util/helper_cuda.hpp"
//#include "cutlass/util/helper_cuda.hpp"
template <class MShape, class NShape, class KShape,
class TA, class AStride, class ABlockLayout, class AThreadLayout,
class TB, class BStride, class BBlockLayout, class BThreadLayout,
class TC, class CStride, class CBlockLayout, class CThreadLayout,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(MShape M, NShape N, KShape K,
TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
Alpha alpha, Beta beta)
__global__ void gemm_device(int M, int N, int K,
float const* A,
float const* B,
float * out, int lda, int ldb, int ldc,
float alpha, float beta)
{
using namespace cute;
using X = Underscore;
// Preconditions
CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
//CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
//CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
//CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
//CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
//CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
//CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
//CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
//CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
// Define block sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
// Define the block layouts (static)
auto bA = make_layout(make_shape(bM,bK));
auto bB = make_layout(make_shape(bN,bK));
auto bC = make_layout(make_shape(bM,bN));
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
//CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC)); // BLK_M
//CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC)); // BLK_N
CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K
//CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB)); // BLK_K
// Shared memory buffers
__shared__ TA smemA[cosize_v<ABlockLayout>];
__shared__ TB smemB[cosize_v<BBlockLayout>];
auto sA = make_tensor(make_smem_ptr(smemA), blockA); // (BLK_M,BLK_K)
auto sB = make_tensor(make_smem_ptr(smemB), blockB); // (BLK_N,BLK_K)
__shared__ float smemA[128*8];
__shared__ float smemB[128*8];
auto sA = make_tensor(make_smem_ptr(smemA), bA); // (BLK_M,BLK_K)
auto sB = make_tensor(make_smem_ptr(smemB), bB); // (BLK_N,BLK_K)
auto dA = make_stride(Int<1>{}, lda);
auto dB = make_stride(Int<1>{}, ldb);
auto dC = make_stride(Int<1>{}, ldc);
// Represent the full tensors
auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K)
@ -3083,11 +3093,27 @@ gemm_device(MShape M, NShape N, KShape K,
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
//template <class MShape, class NShape, class KShape,
// class TA, class AStride, class ABlockLayout, class AThreadLayout,
// class TB, class BStride, class BBlockLayout, class BThreadLayout,
// class TC, class CStride, class CBlockLayout, class CThreadLayout,
// class Alpha, class Beta>
//__global__ static
//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
//void
//gemm_device(MShape M, NShape N, KShape K,
// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
// TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
// half alpha, half beta);
//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

View File

@ -9,7 +9,7 @@
#ifndef kernels
#define kernels
template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
//template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
@ -122,4 +122,24 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
//template <class MShape, class NShape, class KShape,
// class TA, class AStride, class ABlockLayout, class AThreadLayout,
// class TB, class BStride, class BBlockLayout, class BThreadLayout,
// class TC, class CStride, class CBlockLayout, class CThreadLayout,
// class Alpha, class Beta>
//__global__ static
//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
//void
//gemm_device(MShape M, NShape N, KShape K,
// TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
// TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
// Alpha alpha, Beta beta);
__global__ void gemm_device(int M, int N, int K,
float const* A,
float const* B,
float * out, int lda, int ldb, int ldc,
float alpha, float beta);
#endif

View File

@ -91,14 +91,12 @@ template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsign
}
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
{
int num_blocks = (colsB+32-1)/32;
kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *C, int lda, int ldb, int rowsA, int colsA, int colsB);
//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
//{
// int num_blocks = (colsB+32-1)/32;
// kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
// CUDA_CHECK_RETURN(cudaPeekAtLastError());
//}
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
@ -666,60 +664,47 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
#include <cute/tensor.hpp>
#include "cutlass/util/helper_cuda.hpp"
template <typename TA, typename TB, typename TC,
typename Alpha, typename Beta>
void
gemm(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
void gemm_host(int m, int n, int k,
float alpha,
float const* A, int lda,
float const* B, int ldb,
float beta,
float * C, int ldc)
{
cute::device_init(0);
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
// Define strides (mixed)
auto dA = make_stride(Int<1>{}, ldA);
auto dB = make_stride(Int<1>{}, ldB);
auto dC = make_stride(Int<1>{}, ldC);
// Define block sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
printf("%i %i %i %i %i %i\n", m, n, k, lda, ldb, ldc);
// Define the block layouts (static)
auto sA = make_layout(make_shape(bM,bK));
auto sB = make_layout(make_shape(bN,bK));
auto sC = make_layout(make_shape(bM,bN));
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
dim3 dimBlock(size(tC));
dim3 dimGrid(ceil_div(size(M), size(bM)),
ceil_div(size(N), size(bN)));
dim3 dimBlock(16, 16);
dim3 dimGrid((M+127)/128, (N+127)/128);
// auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
//-
//- dim3 dimBlock(size(tC));
//- dim3 dimGrid(ceil_div(size(M), size(bM)),
//- ceil_div(size(N), size(bN)));
gemm_device
<<< dimGrid, dimBlock, 0, stream >>>
<<< dimGrid, dimBlock, 0, 0 >>>
(M, N, K,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
A,
B,
C, lda, ldb, ldc,
alpha, beta);
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================

View File

@ -20,6 +20,11 @@
#include <vector>
#include <functional>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#define CUDA_CHECK_RETURN(value) { \
cudaError_t _m_cudaStat = value; \
if (_m_cudaStat != cudaSuccess) { \
@ -185,4 +190,11 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
void gemm_host(int m, int n, int k,
float alpha,
float const* A, int ldA,
float const* B, int ldB,
float beta,
float * C, int ldC);
#endif

View File

@ -20,6 +20,16 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
void
cppgemm(int m, int n, int k,
float alpha,
float const* A, int ldA,
float const* B, int ldB,
float beta,
float * C, int ldC)
{ gemm_host(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_g##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
@ -306,6 +316,14 @@ extern "C"
void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
void ccutlass_gemm(int m, int n, int k,
float alpha,
float const* A, int ldA,
float const* B, int ldB,
float beta,
float * C, int ldC)
{ cppgemm(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC);}
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }

View File

@ -2351,3 +2351,24 @@ def test_normal_map_tree():
pivots.append((values[i-1]+values[i])/2)
print(pivots)
def test_cutlass3_gemm():
#A = torch.rand(2, 2).cuda()
#B = torch.rand(2, 2).cuda()
A = torch.arange(4).reshape(2, 2).float().cuda().contiguous()
B = torch.ones(2, 2).float().cuda()
print('')
print(A)
print(B)
C1 = torch.matmul(A, B)
print(C1)
C2 = F.cutlass3_gemm(A, B.t())
print(C2)
C2 = F.cutlass3_gemm(A, B)
print(C2)
C2 = F.cutlass3_gemm(B.t(), A.t().contiguous())
print(C2)