Warp multi-specialization 240.
This commit is contained in:
parent
77f15fdce9
commit
869b7e83b5
@ -3058,8 +3058,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||||||
const int half_warp_lane = threadIdx.x % 16;
|
const int half_warp_lane = threadIdx.x % 16;
|
||||||
const int batch_size_warps = (WARPS-1)*2;
|
const int batch_size_warps = (WARPS-1)*2;
|
||||||
|
|
||||||
T local_A[1];
|
T local_A[2];
|
||||||
T local_B[32];
|
T local_B[64];
|
||||||
|
|
||||||
const int a_tile_offset = 16;
|
const int a_tile_offset = 16;
|
||||||
const int b_tile_offset = (16*32 + 16);
|
const int b_tile_offset = (16*32 + 16);
|
||||||
@ -3075,14 +3075,32 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||||||
|
|
||||||
int ticktock = 0;
|
int ticktock = 0;
|
||||||
int idx = 0 + threadIdx.x;
|
int idx = 0 + threadIdx.x;
|
||||||
|
int loaded_values = 0;
|
||||||
// prefetch
|
// prefetch
|
||||||
if(idx < K && warp_id < (WARPS-1))
|
if(idx < K && warp_id < (WARPS-1))
|
||||||
{
|
{
|
||||||
local_A[0] = A[idx];
|
if(loaded_values == 0)
|
||||||
|
{
|
||||||
|
local_A[0] = A[idx];
|
||||||
|
local_A[1] = A[idx+blockDim.x-32];
|
||||||
|
|
||||||
#pragma unroll 32
|
#pragma unroll 32
|
||||||
for(int col = 0; col < 32; col++)
|
for(int col = 0; col < 32; col++)
|
||||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
{
|
||||||
|
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||||
|
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
|
||||||
|
}
|
||||||
|
loaded_values = 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
local_A[0] = local_A[1];
|
||||||
|
loaded_values--;
|
||||||
|
|
||||||
|
#pragma unroll 32
|
||||||
|
for(int col = 0; col < 32; col++)
|
||||||
|
local_B[col] = local_B[col+32];
|
||||||
|
}
|
||||||
|
|
||||||
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
|
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
|
||||||
|
|
||||||
@ -3113,11 +3131,35 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
if(idx < K && warp_id < (WARPS-1))
|
if(idx < K && warp_id < (WARPS-1))
|
||||||
{
|
{
|
||||||
local_A[0] = A[idx];
|
//local_A[0] = A[idx];
|
||||||
|
|
||||||
#pragma unroll 32
|
//#pragma unroll 32
|
||||||
for(int col = 0; col < 32; col++)
|
//for(int col = 0; col < 32; col++)
|
||||||
local_B[col] = B[(col_offset+col)*ldb+idx];
|
// local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||||
|
if(loaded_values == 0)
|
||||||
|
{
|
||||||
|
local_A[0] = A[idx];
|
||||||
|
local_A[1] = A[idx+blockDim.x-32];
|
||||||
|
|
||||||
|
#pragma unroll 32
|
||||||
|
for(int col = 0; col < 32; col++)
|
||||||
|
{
|
||||||
|
local_B[col] = B[(col_offset+col)*ldb+idx];
|
||||||
|
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
|
||||||
|
}
|
||||||
|
loaded_values = 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
local_A[0] = local_A[1];
|
||||||
|
loaded_values--;
|
||||||
|
|
||||||
|
#pragma unroll 32
|
||||||
|
for(int col = 0; col < 32; col++)
|
||||||
|
local_B[col] = local_B[col+32];
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
|
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
|
||||||
|
|
||||||
|
@ -2376,8 +2376,8 @@ def test_cutlass3_gemm(dtype):
|
|||||||
#print('')
|
#print('')
|
||||||
#print(A)
|
#print(A)
|
||||||
#print(B.t())
|
#print(B.t())
|
||||||
#A[:, :-3] = 0
|
#A[:, :-1] = 0
|
||||||
#B[:, :-3] = 0
|
#B[:, :-1] = 0
|
||||||
|
|
||||||
|
|
||||||
C1 = torch.matmul(A, B.t())
|
C1 = torch.matmul(A, B.t())
|
||||||
@ -2399,7 +2399,7 @@ def test_cutlass3_gemm(dtype):
|
|||||||
|
|
||||||
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
|
#if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
|
||||||
# print('')
|
# print('')
|
||||||
# print(i, err, mag.item(), relerr.item())
|
# print(i, err, relerr)
|
||||||
# print(A.flatten()[-6:])
|
# print(A.flatten()[-6:])
|
||||||
# print(B.flatten()[-6:])
|
# print(B.flatten()[-6:])
|
||||||
# out = A.flatten()[-6:]*B.flatten()[-6:]
|
# out = A.flatten()[-6:]*B.flatten()[-6:]
|
||||||
@ -2412,7 +2412,7 @@ def test_cutlass3_gemm(dtype):
|
|||||||
|
|
||||||
c = int(C1.numel()*0.0014*(dim/256))+1
|
c = int(C1.numel()*0.0014*(dim/256))+1
|
||||||
|
|
||||||
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
|
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True)
|
||||||
#print(c/math.sqrt(dim))
|
#print(c/math.sqrt(dim))
|
||||||
print('')
|
print('')
|
||||||
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
||||||
|
Loading…
Reference in New Issue
Block a user