From 6e2544da251ccf281d5d88611d2cb5c13bcf42a6 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Tue, 25 Apr 2023 16:15:44 -0700
Subject: [PATCH] Added cutlass example.

---
 csrc/kernels.cu | 134 ++++++++++++++++++++++++++++++++++++++++++++++++
 csrc/ops.cu     |  57 ++++++++++++++++++++
 2 files changed, 191 insertions(+)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 5d2a58e..a108772 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2942,6 +2942,140 @@ template __global
   // 9. write outputs to matmul output matrix
 }
 
+#include "cutlass/util/print_error.hpp"
+#include "cutlass/util/GPU_Clock.hpp"
+#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
+# include "cutlass/util/cublas_wrappers.hpp"
+#endif
+#include "cutlass/util/helper_cuda.hpp"
+
+template <class MShape, class NShape, class KShape,
+          class TA, class AStride, class ABlockLayout, class AThreadLayout,
+          class TB, class BStride, class BBlockLayout, class BThreadLayout,
+          class TC, class CStride, class CBlockLayout, class CThreadLayout,
+          class Alpha, class Beta>
+__global__ static
+__launch_bounds__(decltype(size(CThreadLayout{}))::value)
+void
+gemm_device(MShape M, NShape N, KShape K,
+            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
+            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
+            TC      * C, CStride dC, CBlockLayout      , CThreadLayout tC,
+            Alpha alpha, Beta beta)
+{
+  using namespace cute;
+  using X = Underscore;
+
+  // Preconditions
+  CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
+  CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
+  CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);
+
+  CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
+  CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
+  CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);
+
+  CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
+  CUTE_STATIC_ASSERT_V(size(tB) == size(tC));
+
+  //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC));      // BLK_M
+  //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC));      // BLK_N
+  CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB));        // BLK_K
+
+  // Shared memory buffers
+  __shared__ TA smemA[cosize_v<ABlockLayout>];
+  __shared__ TB smemB[cosize_v<BBlockLayout>];
+  auto sA = make_tensor(make_smem_ptr(smemA), blockA);               // (BLK_M,BLK_K)
+  auto sB = make_tensor(make_smem_ptr(smemB), blockB);               // (BLK_N,BLK_K)
+
+  // Represent the full tensors
+  auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA);      // (M,K)
+  auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB);      // (N,K)
+  auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC);      // (M,N)
+
+  // Get the appropriate blocks for this thread block --
+  // potential for thread block locality
+  auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));  // (BLK_M,BLK_N,BLK_K)
+  auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
+
+  auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
+  auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
+  auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)
+
+  //
+  // Partition the copying of A and B tiles across the threads
+  //
+
+  // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB
+  //   Default is a raked partition, but can be changed with Step parameter
+
+  auto tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
+  auto tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)
+
+  auto tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
+  auto tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)
+
+  //
+  // Define C accumulators and A/B partitioning
+  //
+
+  // TUTORIAL: Example of partitioning via projections of tC
+
+  // Partition sA (M,K) by the rows of tC
+  auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
+  // Partition sB (N,K) by the cols of tC
+  auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
+  // Partition gC (M,N) by the tile of tC
+  auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)
+
+  // Allocate the accumulators -- same size as the projected data
+  auto tCrC = make_fragment_like(tCgC);                              // (THR_M,THR_N)
+
+  // Clear the accumulators
+  clear(tCrC);
+
+#if 1
+
+  // TUTORIAL: Example of a very simple compute loop
+  //   Data is read from global to shared memory via the tA|tB partitioning
+  //   gemm(.) operates on the shared memory directly via the tC partitioning
+
+  auto k_max = size<2>(tAgA);
+
+  for (int k = 0; k < k_max; ++k)
+  {
+    // Copy gmem to smem
+    copy(tAgA(_,_,k), tAsA);
+    copy(tBgB(_,_,k), tBsB);
+
+    // In case copy uses cp.async, make sure that the cp.async
+    // instructions are ordered with respect to other cp.async
+    // instructions (fence), then wait on all the outstanding copy
+    // operations (wait<0>()). __syncthreads() alone does not do
+    // this.
+    //
+    // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all.
+    // This is equivalent to cp.async.commit_group followed by
+    // cp.async.wait_group 0. This should make the first
+    // cp_async_fence() (which also issues cp.async.commit_group)
+    // redundant. The tutorial works as-is, so we'll leave the
+    // redundant fence in for now and study its removal later.
+    cp_async_fence();
+    cp_async_wait<0>();
+
+    __syncthreads();
+
+    // Compute gemm on smem
+    gemm(tCsA, tCsB, tCrC);
+
+    __syncthreads();
+  }
+
+#endif
+
+  axpby(alpha, tCrC, beta, tCgC);
+}
+
 
 //==============================================================
 // TEMPLATE DEFINITIONS
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 022f397..1204cbd 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -665,6 +665,63 @@ template void extractOutliers(char * A, int *idx, char *out, int id
 }
 
 
+
+#include
+#include
+
+#include <cute/tensor.hpp>
+
+template <typename TA, typename TB, typename TC,
+          typename Alpha, typename Beta>
+void
+gemm(int m, int n, int k,
+     Alpha alpha,
+     TA const* A, int ldA,
+     TB const* B, int ldB,
+     Beta beta,
+     TC      * C, int ldC,
+     cudaStream_t stream = 0)
+{
+  using namespace cute;
+
+  // Define shapes (dynamic)
+  auto M = int(m);
+  auto N = int(n);
+  auto K = int(k);
+
+  // Define strides (mixed)
+  auto dA = make_stride(Int<1>{}, ldA);
+  auto dB = make_stride(Int<1>{}, ldB);
+  auto dC = make_stride(Int<1>{}, ldC);
+
+  // Define block sizes (static)
+  auto bM = Int<128>{};
+  auto bN = Int<128>{};
+  auto bK = Int<  8>{};
+
+  // Define the block layouts (static)
+  auto sA = make_layout(make_shape(bM,bK));
+  auto sB = make_layout(make_shape(bN,bK));
+  auto sC = make_layout(make_shape(bM,bN));
+
+  // Define the thread layouts (static)
+  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
+  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
+  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));
+
+  dim3 dimBlock(size(tC));
+  dim3 dimGrid(ceil_div(size(M), size(bM)),
+               ceil_div(size(N), size(bN)));
+  gemm_device
+      <<< dimGrid, dimBlock, 0, stream >>>
+      (M,  N,  K,
+       A, dA, sA, tA,
+       B, dB, sB, tB,
+       C, dC, sC, tC,
+       alpha, beta);
+}
+
+
 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
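
Note (not part of the patch): the sketch below shows one way the gemm<>() wrapper added to csrc/ops.cu above might be driven from host code, assuming the template is visible from the calling translation unit. The function name example_sgemm, the buffer names d_A/d_B/d_C, and the problem sizes are illustrative assumptions, not anything this commit defines. The strides built inside gemm() make A an (m x k) and B an (n x k) column-major matrix, so this is an NT GEMM: C(i,j) = alpha * sum_k A(i,k) * B(j,k) + beta * C(i,j).

    // Hypothetical usage sketch; assumes gemm<>() from csrc/ops.cu is reachable here.
    #include <cuda_runtime.h>

    void example_sgemm(cudaStream_t stream)
    {
      // The tutorial kernel has no bounds predication, so m and n should be
      // multiples of the 128x128 block tile and k a multiple of the 8-wide k-tile.
      int m = 4096, n = 4096, k = 4096;

      float *d_A, *d_B, *d_C;
      cudaMalloc(&d_A, sizeof(float) * m * k);   // A: (m x k), column-major, ldA = m
      cudaMalloc(&d_B, sizeof(float) * n * k);   // B: (n x k), column-major, ldB = n
      cudaMalloc(&d_C, sizeof(float) * m * n);   // C: (m x n), column-major, ldC = m

      // ... initialize d_A and d_B ...

      // C = 1.0 * A * B^T + 0.0 * C, launched on the caller's stream
      gemm(m, n, k, 1.0f, d_A, m, d_B, n, 0.0f, d_C, m, stream);

      cudaStreamSynchronize(stream);
      cudaFree(d_A); cudaFree(d_B); cudaFree(d_C);
    }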