Added bit template.
This commit is contained in:
parent
f3e97ccbd2
commit
cad839941b
csrc/kernels.cu
@@ -2947,16 +2947,31 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 //// 9. write outputs to matmul output matrix
 //}
 
 #define ROWS 2
-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit)
+{
+  if(limit_base + ITEMS <= limit)
+    reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
+  else
+  {
+    for(int k = 0; k < ITEMS; k++)
+    {
+      if(limit_base + k < limit)
+        local[k] = buffer[idx+k];
+      else
+        local[k] = 0.0f;
+    }
+  }
+}
+
+template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
 
   typedef cub::BlockReduce<T, THREADS> BlockReduce;
   __shared__ typename BlockReduce::TempStorage reduce;
   int col_offset = blockIdx.x *8;
 
-  T local_A[8];
-  T local_B[8];
+  T local_A[128/BITS];
+  T local_B[128/BITS];
   T local_C[8];
 
   __shared__ T smem_C[8];
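
Note: the new vector_load helper is the commit's central abstraction. When a full ITEMS-wide chunk fits under limit, it issues one 128-bit load through the TCAST type (int4 at the call sites below); otherwise it falls back to a scalar tail loop that zero-pads local. The vector path assumes buffer is 16-byte aligned and idx is a multiple of ITEMS. A minimal standalone sketch of the pattern — the kernel name, launch shape, and K value are mine, not the commit's:

    #include <cstdio>
    #include <cuda_fp16.h>

    template <typename T, typename TCAST, int ITEMS>
    __device__ inline void vector_load(T *local, T * __restrict__ const buffer,
                                       int idx, int limit_base, int limit)
    {
      if(limit_base + ITEMS <= limit)
        // one vectorized 128-bit load (TCAST = int4 -> 16 bytes)
        reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
      else
        for(int k = 0; k < ITEMS; k++)              // scalar tail, zero-padded
          local[k] = (limit_base + k < limit) ? buffer[idx+k] : T(0.0f);
    }

    __global__ void smoke_vector_load(half *buf, float *out, int K)
    {
      half local[8];
      int idx = threadIdx.x * 8;
      vector_load<half, int4, 8>(local, buf, idx, idx, K);  // 8 halves = one int4
      float s = 0.0f;
      for(int k = 0; k < 8; k++)
        s += __half2float(local[k]);
      out[threadIdx.x] = s;
    }

    int main()
    {
      half *buf; float *out;
      cudaMalloc(&buf, 19 * sizeof(half));    // K = 19: not a multiple of 8,
      cudaMalloc(&out, 3 * sizeof(float));    // so thread 2 takes the tail path
      cudaMemset(buf, 0, 19 * sizeof(half));
      smoke_vector_load<<<1, 3>>>(buf, out, 19);
      cudaDeviceSynchronize();
      return 0;
    }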
@@ -2970,47 +2985,18 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
     local_C[k] = T(0);
 
-  for(int idx = threadIdx.x*8; idx < K; idx+=blockDim.x*8)
+  for(int idx = threadIdx.x*128/BITS; idx < K; idx+=blockDim.x*128/BITS)
   {
 
-    if(idx + 8 <= K)
-      reinterpret_cast<float4(&)[8]>(local_A)[0] = reinterpret_cast<float4*>(A)[idx/8];
-    else
-    {
-      for(int k = 0; k < 8; k++)
-      {
-        if(idx + k < K)
-          local_A[k] = A[idx+k];
-        else
-          local_A[k] = 0.0f;
-      }
-    }
+    vector_load<T, int4, 128/BITS>(local_A, A, idx, idx, K);
 
     for(int col = 0; col < 8; col++)
     {
       int offset_B = (col_offset+col)*ldb;
-      if(idx + 8 <= K)
-        reinterpret_cast<float4(&)[8]>(local_B)[0] = reinterpret_cast<float4*>(B)[(offset_B+idx)/8];
-      else
-      {
-        for(int k = 0; k < 8; k++)
-        {
-          if(idx + k < K)
-            local_B[k] = B[(offset_B+idx)+k];
-          else
-            local_B[k] = 0.0f;
-        }
-      }
+      vector_load<T, int4, 128/BITS>(local_B, B, offset_B+idx, idx, K);
 
-      #pragma unroll 8
-      for(int k = 0; k < 8; k++)
+      #pragma unroll 128/BITS
+      for(int k = 0; k < 128/BITS; k++)
       {
         local_C[col] += local_A[k]*local_B[k];
         //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
         //  printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]);
       }
 
     }
   }
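
Note: the loop is now keyed entirely off 128/BITS, the number of elements that fit in one 128-bit load: 8 for BITS=16 (half), 4 for BITS=32 (float). The old hard-coded 8 matched only the 16-bit case; with T = float, the single float4 load filled just 4 of local_A's 8 slots before the k < 8 multiply, so deriving the count from BITS appears to be an actual fix for the float path, not just a generalization. The arithmetic, spelled out (illustrative code, not from the commit):

    #include <cstdio>

    // elements carried per 128-bit (int4) load for a given element width
    template <int BITS> constexpr int items_per_load() { return 128 / BITS; }

    int main()
    {
      printf("half  (BITS=16): %d elements per load\n", items_per_load<16>());  // 8
      printf("float (BITS=32): %d elements per load\n", items_per_load<32>());  // 4
      return 0;
    }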
@@ -3022,9 +3008,11 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
   }
 
   if(threadIdx.x == 0)
   {
     #pragma unroll 8
     for(int k = 0; k < 8; k++)
       smem_C[k] = local_C[k];
   }
+  else if(threadIdx.x >= 32)
+    // early return for unused warps
+    return;
@@ -3032,15 +3020,8 @@ template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M,
   __syncwarp();
 
-
-  //for(int k = 0; k < 8; k++)
-  //  if((float)local_C[k] != 0.0f)
-  //    printf("%i %f\n", threadIdx.x, (float)local_C[k]);
-
   if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
     out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
 
-
-
 }
 
 //#define ROWS 2
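
Note: the early return added above is safe only because nothing after it block-synchronizes: __syncwarp() synchronizes just the calling warp, and every thread that reaches it is in warp 0 (threadIdx.x < 32). Had the epilogue used __syncthreads(), the returned warps would deadlock the block. The pattern in isolation — kernel and variable names are mine:

    __global__ void epilogue_pattern(float *out, int M, int col_offset)
    {
      __shared__ float smem_C[8];
      float local_C[8] = {};        // stands in for the reduced per-column sums

      if(threadIdx.x == 0)
        for(int k = 0; k < 8; k++)
          smem_C[k] = local_C[k];   // thread 0 stages the results
      else if(threadIdx.x >= 32)
        return;                     // warps 1..N exit; no __syncthreads() follows

      __syncwarp();                 // warp 0: order smem writes before the reads

      if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
        out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
    }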
@@ -3217,7 +3198,13 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
 //                      TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
 //                      TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
 //                      half alpha, half beta);
 
+// these are not used and make no sense, but the compiler needs them
+template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
+
 // these are not used and make no sense, but the compiler needs them
 
 template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
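
Note: gemm_device is defined only in this translation unit, so every <T, BITS, THREADS> combination that gemm_host can select at runtime needs an explicit instantiation here. Because the if(bits == 32)/else if(bits == 16) dispatch below instantiates both branches for each T, the "make no sense" pairs (<float, 16, 128> and <half, 32, 128>) must exist even if never launched. The mechanism, reduced to a toy example with hypothetical names:

    #include <cuda_fp16.h>

    // definition lives in this .cu file only
    template <typename T, int BITS>
    __global__ void kernel(T *x) { if(x) x[0] = T(float(BITS)); }

    // explicit instantiations emit device code for each combination some other
    // translation unit might launch; without them, the link fails with an
    // undefined-symbol error for that specialization.
    template __global__ void kernel<half, 16>(half *);
    template __global__ void kernel<half, 32>(half *);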
csrc/kernels.cuh
@@ -138,6 +138,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <size_t stages_count /* Pipeline with stages_count stages */>
 __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 
 #endif
csrc/ops.cu
@@ -675,7 +675,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
 
 
 
-template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
 {
 
   dim3 dimBlock(128);
@@ -689,20 +689,18 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
   cout << m << endl;
   cout << n << endl;
   cout << k << endl;
-  gemm_device<T, 16, 128>
-    <<< num_blocks, dimBlock, 0, 0 >>>
-    (m, n, k,
-     A,
-     B,
-     out, lda, ldb, ldc);
+  if(bits == 32)
+    gemm_device<T, 32, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+  else if(bits == 16)
+    gemm_device<T, 16, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
 }
 
 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
 
-template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc);
-template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc);
+template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
+template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
 template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
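
Note: callers now pick the kernel's element width at runtime through the trailing bits argument, and the host wrapper routes to the matching gemm_device instantiation. A hedged usage sketch — the wrapper name and the column-major leading dimensions are my assumptions, not pinned down by this commit:

    #include <cuda_fp16.h>

    // declaration as changed in ops.cuh; definition is in ops.cu
    template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out,
                                         int lda, int ldb, int ldc, int bits);

    void run_half_gemm(half *A, half *B, half *out, int m, int n, int k)
    {
      // bits == 16 -> gemm_device<half, 16, 128>; bits == 32 -> <half, 32, 128>
      gemm_host<half>(m, n, k, A, B, out, m, k, m, 16);
    }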
csrc/ops.cuh
@@ -190,7 +190,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 
 void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
 
-template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 
 
 void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
csrc/pythonInterface.c
@@ -21,9 +21,9 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
 
 
 void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
-{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc); }
+{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
 void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
-{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc); }
+{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
 
 
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
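
Note: the C entry points keep their old signatures and bake in the width that matches their type, so callers on the Python side are unaffected. The resulting call chain, traced from the diff:

    // gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc)
    //   -> gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32)
    //      -> gemm_device<float, 32, 128><<<num_blocks, 128>>>(...)
    //
    // gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc)
    //   -> gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16)
    //      -> gemm_device<half, 16, 128><<<num_blocks, 128>>>(...)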
tests/test_functional.py
@@ -2352,8 +2352,8 @@ def test_normal_map_tree():
   print(pivots)
 
 
-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
     for i in range(1):
         #A = torch.rand(2, 4092, dtype=dtype, device='cuda')