Added pipeline draft.
commit 9cab14a3ff (parent d1c4c20568)
@@ -2341,3 +2341,8 @@ def extract_outliers(A, SA, idx):
     post_call(prev_device)
 
     return out
+
+def pipeline_test(A, batch_size):
+    out = torch.zeros_like(A)
+    lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
+    return out
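Aside (not part of the diff): the new binding is exercised exactly like the test added at the bottom of this commit. A minimal usage sketch, assuming a contiguous float32 CUDA tensor so that get_ptr hands the CUDA side a valid float pointer:

    import torch
    import bitsandbytes.functional as F

    a = torch.rand(2, 4).cuda()     # any contiguous float32 CUDA tensor
    out = F.pipeline_test(a, 2)     # same shape as a; the draft kernel's compute() is still a stub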
@@ -15,6 +15,9 @@
 #include <thrust/host_vector.h>
 #include <thrust/device_vector.h>
+
+#include <cooperative_groups/memcpy_async.h>
+#include <cuda/pipeline>
 
 #define HLF_MAX 65504
 #define TH 1024
 #define NUM 4
@@ -2983,6 +2986,51 @@ __global__ void gemm_device(int M, int N, int K,
 }
 
+
+__device__ void compute(float* global_out, float const* shared_in)
+{
+
+}
+
+template <size_t stages_count /* Pipeline with stages_count stages */>
+__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz) {
+    auto grid = cooperative_groups::this_grid();
+    auto block = cooperative_groups::this_thread_block();
+    assert(size == batch_sz * grid.size()); // Assume input size fits batch_sz * grid_size
+
+    extern __shared__ float shared[]; // stages_count * block.size() * sizeof(int) bytes
+    size_t shared_offset[stages_count];
+    for (int s = 0; s < stages_count; ++s) shared_offset[s] = s * block.size();
+
+    __shared__ cuda::pipeline_shared_state<
+        cuda::thread_scope::thread_scope_block,
+        stages_count
+    > shared_state;
+    auto pipeline = cuda::make_pipeline(block, &shared_state);
+
+    auto block_batch = [&](size_t batch) -> int {
+        return block.group_index().x * block.size() + grid.size() * batch;
+    };
+
+    // compute_batch: next batch to process
+    // fetch_batch: next batch to fetch from global memory
+    for (size_t compute_batch = 0, fetch_batch = 0; compute_batch < batch_sz; ++compute_batch) {
+        // The outer loop iterates over the computation of the batches
+        for (; fetch_batch < batch_sz && fetch_batch < (compute_batch + stages_count); ++fetch_batch) {
+            // This inner loop iterates over the memory transfers, making sure that the pipeline is always full
+            pipeline.producer_acquire();
+            size_t shared_idx = fetch_batch % stages_count;
+            size_t batch_idx = fetch_batch;
+            size_t block_batch_idx = block_batch(batch_idx);
+            cuda::memcpy_async(block, shared + shared_offset[shared_idx], global_in + block_batch_idx, sizeof(float) * block.size(), pipeline);
+            pipeline.producer_commit();
+        }
+        pipeline.consumer_wait();
+        int shared_idx = compute_batch % stages_count;
+        int batch_idx = compute_batch;
+        compute(global_out + block_batch(batch_idx), shared + shared_offset[shared_idx]);
+        pipeline.consumer_release();
+    }
+}
+
 //==============================================================
 // TEMPLATE DEFINITIONS
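Aside (not part of the diff): the kernel follows the staged-copy pattern from the libcu++ cuda::pipeline examples: the producer phase prefetches up to stages_count batches into shared memory with cuda::memcpy_async, and the consumer phase processes each batch as it lands, so copies overlap with compute. Because the staging buffer is declared extern __shared__, a launch has to pass the dynamic shared-memory size explicitly. A sketch of that sizing, with hypothetical variable names (num_blocks, in, out, n, batch_size):

    constexpr size_t stages = 2;                           // double buffering
    int threads = 256;
    size_t smem_bytes = stages * threads * sizeof(float);  // one block-sized float buffer per stage
    with_staging_unified<stages><<<num_blocks, threads, smem_bytes>>>(in, out, n, batch_size);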
@@ -3004,6 +3052,7 @@ __global__ void gemm_device(int M, int N, int K,
 
 
 //template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
+template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 
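Aside (not part of the diff): with_staging_unified is a template defined in this .cu file, so the host launcher added in csrc/ops.cu below can only link against specializations that are explicitly instantiated here. This hunk adds the stages_count == 2 specialization, matching the forward declaration added to the header in the next hunk. The pattern, in sketch form:

    // header: declaration visible to other translation units
    template <size_t stages_count>
    __global__ void with_staging_unified(float const* global_in, float* global_out, size_t size, size_t batch_sz);

    // kernels.cu: definition plus the specializations other .cu files launch
    template __global__ void with_staging_unified<2>(float const*, float*, size_t, size_t);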
@@ -135,6 +135,8 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 // TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 // TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
 // Alpha alpha, Beta beta);
+template <size_t stages_count /* Pipeline with stages_count stages */>
+__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
 __global__ void gemm_device(int M, int N, int K,
     float const* A,
csrc/ops.cu
@@ -663,6 +663,17 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 }
 
 
+
+void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
+{
+  int threads = 256;
+  int num_blocks = (n+(256*batch_size)+1)/(batch_size*256);
+
+  printf("%i %i\n", num_blocks, batch_size);
+
+  with_staging_unified<2><<<num_blocks, threads>>>(A, B, n, batch_size);
+  CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
 
 
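Aside (not part of the diff): as committed, the launch omits the dynamic shared-memory argument that the kernel's extern __shared__ buffer needs, the grid-size expression adds one extra block even when n divides evenly (which trips the kernel's size == batch_sz * grid.size() assert), and the printf passes a size_t to %i. A corrected sketch of the launcher under those assumptions, not the committed code:

    void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
    {
      int threads = 256;
      int num_blocks = (n + threads*batch_size - 1) / (threads*batch_size);  // ceil-divide, no extra block
      size_t smem_bytes = 2 * threads * sizeof(float);                        // stages_count == 2 staging buffers

      printf("%i %zu\n", num_blocks, batch_size);

      with_staging_unified<2><<<num_blocks, threads, smem_bytes>>>(A, B, n, batch_size);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
    }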
@@ -197,4 +197,6 @@ void gemm_host(int m, int n, int k,
     float beta,
     float * C, int ldC);
 
+void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
+
 #endif
@@ -315,6 +315,7 @@ extern "C"
 
   void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
   void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
+  void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }
 
   void ccutlass_gemm(int m, int n, int k,
     float alpha,
@@ -2366,3 +2366,8 @@ def test_cutlass3_gemm():
     C2 = F.cutlass3_gemm(A, B)
 
 
+def test_pipeline_func():
+    a = torch.rand(2, 4).cuda()
+    out = F.pipeline_test(a, 2)
+    print(a)
+    print(out)
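Aside (not part of the diff): the test only prints the tensors, which fits a draft since compute() is still an empty stub and out is expected to stay at its zero initialization. Once compute() copies its staged input through to global_out, the test could assert the round trip instead; a hypothetical sketch:

    def test_pipeline_func():
        a = torch.rand(2, 4).cuda()
        out = F.pipeline_test(a, 2)
        torch.testing.assert_close(out, a)  # valid only once compute() does a pass-through copy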