Baseline for debugging.
commit f9bfea8f23 (parent 7bfa09d0fc)
@@ -1467,7 +1467,7 @@ def cutlass3_gemm(
     lda = Bshape[1]
     ldc = Bshape[0]
     ldb = (ldb+1)//2
-    print(m, n, k, lda, ldb, ldc)
+    #print(m, n, k, lda, ldb, ldc)
     is_on_gpu([B, A, out])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
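Side note on the (ldb+1)//2 line kept as context above: if B is stored with two 4-bit values packed per byte (an assumption, not something this hunk states), a row of ldb elements occupies ceil(ldb/2) bytes, which is exactly what that expression computes. A minimal sketch:

    # Illustrative only: ceil division for a hypothetical 4-bit packed row.
    def packed_row_bytes(n_elements: int) -> int:
        # two 4-bit values per byte -> ceil(n_elements / 2)
        return (n_elements + 1) // 2

    assert packed_row_bytes(4096) == 2048
    assert packed_row_bytes(4097) == 2049  # odd counts round up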
@@ -3061,9 +3061,8 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
   T local_A[1];
   T local_B[32];

-  const int a_tile_offset = (8*16 + 16);
-  const int b_tile_offset = (16*32 + 16);
-  const int c_tile_offset = 8*32 + 24;
+  const int a_tile_offset = (8*16);
+  const int b_tile_offset = (16*32);

   __shared__ T smem_A[2*batch_size_warps*8*16 + (2*16*(batch_size_warps-1))];
   __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
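A quick size check for the smem_A/smem_B declarations above. The values batch_size_warps = 4 and T = half are assumptions for illustration (they are not visible in this hunk); the point is only that the two buffers stay well under the default 48 KB of static shared memory per block.

    # Assumed values for illustration; the actual template parameters are not in this hunk.
    batch_size_warps = 4
    bytes_per_elem = 2  # sizeof(half), assumed

    smem_A_elems = 2*batch_size_warps*8*16  + (2*16*(batch_size_warps-1))   # 1120
    smem_B_elems = 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))   # 4192

    total_bytes = (smem_A_elems + smem_B_elems) * bytes_per_elem
    print(smem_A_elems, smem_B_elems, total_bytes)  # 1120 4192 10624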
@@ -3109,6 +3108,19 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       for(int col = 0; col < 32; col++)
         smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = local_B[col];
     }
+    else if(warp_id < (WARPS-1))
+    {
+      local_A[0] = T(0.0);
+      smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = T(0.0);
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B[col] = T(0.0f);
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = T(0.0f);
+    }
     ticktock = ticktock == 0 ? 1 : 0;

     for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32)
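The ticktock flip at the end of the block above drives a two-stage double buffer: while one half of shared memory is being filled for the next iteration, the other half is consumed. A small Python model of that indexing pattern (purely illustrative; it mirrors only the flip and the drain after the loop, not the kernel itself):

    # Two logical buffers; the producer fills buffers[ticktock] while the
    # consumer reads the previously filled buffer (1 - ticktock).
    buffers = [None, None]
    ticktock = 0
    chunks = [list(range(i, i + 4)) for i in range(0, 16, 4)]
    for step, chunk in enumerate(chunks):
        buffers[ticktock] = chunk                     # fill current buffer
        if step > 0:
            print("consume", buffers[1 - ticktock])   # previous buffer is ready
        ticktock = 1 - ticktock                       # same as: ticktock == 0 ? 1 : 0
    print("consume", buffers[1 - ticktock])           # drain the last filled buffer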
@@ -3130,6 +3142,19 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       for(int col = 0; col < 32; col++)
         smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
     }
+    else if(warp_id < (WARPS-1))
+    {
+      local_A[0] = T(0.0);
+      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B[col] = 0.0f;
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
+    }
     ticktock = ticktock == 0 ? 1 : 0;

     if(warp_id == (WARPS-1))
csrc/ops.cu (14 changed lines)
@@ -680,14 +680,14 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out

  int num_blocks = (m+31)/32;

-  cout << num_blocks << endl;
-  cout << lda << endl;
-  cout << ldb << endl;
-  cout << ldc << endl;
+  //cout << num_blocks << endl;
+  //cout << lda << endl;
+  //cout << ldb << endl;
+  //cout << ldc << endl;

-  cout << m << endl;
-  cout << n << endl;
-  cout << k << endl;
+  //cout << m << endl;
+  //cout << n << endl;
+  //cout << k << endl;
   //if(bits == 32)
     //gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
     //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
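On the unchanged num_blocks line above: (m+31)/32 is integer ceil division, i.e. it rounds m up to a whole number of 32-sized blocks for the launch. A trivial check:

    def num_blocks(m: int, rows_per_block: int = 32) -> int:
        # ceil(m / rows_per_block) with integer arithmetic, as in (m + 31) / 32
        return (m + rows_per_block - 1) // rows_per_block

    assert num_blocks(1) == 1
    assert num_blocks(32) == 1
    assert num_blocks(33) == 2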
@@ -2355,25 +2355,47 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(1):
+    for i in range(100):
         #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
         #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
         #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
         #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+        A = torch.randn(1, 128+32, dtype=dtype, device='cuda')
+        B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128)

         #print('')
         #print(A)
         #print(B.t())
+        #A[:, :-3] = 0
+        #B[:, :-3] = 0


         C1 = torch.matmul(A, B.t())
         C2 = F.cutlass3_gemm(A, B.t())
-        print(C1)
-        print(C2)
+        err = C1-C2

-        torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.06)
+        # tensor cores are non-deterministic
+        # so we need to analyze errors around the mean
+        # to test our implementation
+        err = torch.abs(err.mean()).item()
+        mag = torch.abs(C1).mean()
+        relerr = err/mag
+
+        if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            print('')
+            print(i, err, mag.item(), relerr.item())
+            print(A.flatten()[-6:])
+            print(B.flatten()[-6:])
+            out = A.flatten()[-6:]*B.flatten()[-6:]
+            print(out)
+            print(out[:-1].sum())
+            print('='*80)
+            print(C1.flatten()[-6:])
+            print(C2.flatten()[-6:])
+            #assert False, 'ERROR'
+
+        c = int(C1.numel()*0.001)
+        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)

 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
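Because tensor-core accumulation order is non-deterministic, the test above now checks the error around the mean and then allows a small number of outlier elements via assert_all_approx_close. The helper itself is not part of this diff; a hypothetical implementation consistent with the call assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c) might look like this (assumed semantics: at most `count` elements may violate the tolerance):

    import torch

    def assert_all_approx_close(a, b, atol, rtol, count):
        # Assumed semantics: fail only if more than `count` elements
        # fall outside the atol/rtol tolerance.
        close = torch.isclose(a, b, atol=atol, rtol=rtol)
        n_bad = (~close).sum().item()
        if n_bad > count:
            raise AssertionError(f"{n_bad} elements differ (allowed: {count})")

    # Usage mirroring the test: tolerate roughly 0.1% mismatched elements.
    # c = int(C1.numel() * 0.001)
    # assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)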