Warp multi-specialization 240.
This commit is contained in:
parent
77f15fdce9
commit
869b7e83b5
|
@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
const int half_warp_lane = threadIdx.x % 16;
|
||||
const int batch_size_warps = (WARPS-1)*2;
|
||||
|
||||
T local_A[1];
|
||||
T local_B[32];
|
||||
T local_A[2];
|
||||
T local_B[64];
|
||||
|
||||
const int a_tile_offset = 16;
|
||||
const int b_tile_offset = (16*32 + 16);
|
||||
|
@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
|
||||
int ticktock = 0;
|
||||
int idx = 0 + threadIdx.x;
|
||||
int loaded_values = 0;
|
||||
// prefetch
|
||||
if(idx < K && warp_id < (WARPS-1))
|
||||
{
|
||||
local_A[0] = A[idx];
|
||||
if(loaded_values == 0)
|
||||
{
|
||||
local_A[0] = A[idx];
|
||||
local_A[1] = A[idx+blockDim.x-32];
|
||||
|
||||
#pragma unroll 32
|
||||
for(int col = 0; col < 32; col++)
|
||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||
#pragma unroll 32
|
||||
for(int col = 0; col < 32; col++)
|
||||
{
|
||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
|
||||
}
|
||||
loaded_values = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
local_A[0] = local_A[1];
|
||||
loaded_values--;
|
||||
|
||||
#pragma unroll 32
|
||||
for(int col = 0; col < 32; col++)
|
||||
local_B[col] = local_B[col+32];
|
||||
}
|
||||
|
||||
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
|
||||
|
||||
|
@ -3113,11 +3131,35 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
__syncthreads();
|
||||
if(idx < K && warp_id < (WARPS-1))
|
||||
{
|
||||
local_A[0] = A[idx];
|
||||
//local_A[0] = A[idx];
|
||||
|
||||
#pragma unroll 32
|
||||
for(int col = 0; col < 32; col++)
|
||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||
//#pragma unroll 32
|
||||
//for(int col = 0; col < 32; col++)
|
||||
// local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||
if(loaded_values == 0)
|
||||
{
|
||||
local_A[0] = A[idx];
|
||||
local_A[1] = A[idx+blockDim.x-32];
|
||||
|
||||
#pragma unroll 32
|
||||
for(int col = 0; col < 32; col++)
|
||||
{
|
||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
|
||||
}
|
||||
loaded_values = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
local_A[0] = local_A[1];
|
||||
loaded_values--;
|
||||
|
||||
#pragma unroll 32
|
||||
for(int col = 0; col < 32; col++)
|
||||
local_B[col] = local_B[col+32];
|
||||
|
||||
|
||||
}
|
||||
|
||||
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
|
||||
|
||||
|
|
|
@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype):
|
|||
#print('')
|
||||
#print(A)
|
||||
#print(B.t())
|
||||
#A[:, :-3] = 0
|
||||
#B[:, :-3] = 0
|
||||
#A[:, :-1] = 0
|
||||
#B[:, :-1] = 0
|
||||
|
||||
|
||||
C1 = torch.matmul(A, B.t())
|
||||
|
@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype):
|
|||
|
||||
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
|
||||
# print('')
|
||||
# print(i, err, mag.item(), relerr.item())
|
||||
# print(i, err, relerr)
|
||||
# print(A.flatten()[-6:])
|
||||
# print(B.flatten()[-6:])
|
||||
# out = A.flatten()[-6:]*B.flatten()[-6:]
|
||||
|
@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype):
|
|||
|
||||
c = int(C1.numel()*0.0014*(dim/256))+1
|
||||
|
||||
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
|
||||
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True)
|
||||
#print(c/math.sqrt(dim))
|
||||
print('')
|
||||
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
||||
|
|
Loading…
Reference in New Issue
Block a user