Warp multi-specialization 240.

This commit is contained in:
Tim Dettmers 2023-05-02 12:10:32 -07:00
parent 77f15fdce9
commit 869b7e83b5
2 changed files with 56 additions and 14 deletions

View File

@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;
T local_A[1];
T local_B[32];
T local_A[2];
T local_B[64];
const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16);
@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if(idx < K && warp_id < (WARPS-1))
{
local_A[0] = A[idx];
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = B[(col_offset+col)*ldb+idx];
#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
}
loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+32];
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
@ -3113,11 +3131,35 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
__syncthreads();
if(idx < K && warp_id < (WARPS-1))
{
local_A[0] = A[idx];
//local_A[0] = A[idx];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = B[(col_offset+col)*ldb+idx];
//#pragma unroll 32
//for(int col = 0; col < 32; col++)
// local_B[col] = B[(col_offset+col)*ldb+idx];
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];
#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
}
loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+32];
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

View File

@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype):
#print('')
#print(A)
#print(B.t())
#A[:, :-3] = 0
#B[:, :-3] = 0
#A[:, :-1] = 0
#B[:, :-1] = 0
C1 = torch.matmul(A, B.t())
@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype):
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
# print('')
# print(i, err, mag.item(), relerr.item())
# print(i, err, relerr)
# print(A.flatten()[-6:])
# print(B.flatten()[-6:])
# out = A.flatten()[-6:]*B.flatten()[-6:]
@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype):
c = int(C1.numel()*0.0014*(dim/256))+1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True)
#print(c/math.sqrt(dim))
print('')
print(dim, sum(errs)/len(errs)/math.sqrt(dim))