Tighter and scaled error analysis.
parent f9bfea8f23
commit 9192c9de64
@@ -3123,6 +3123,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     }
     ticktock = ticktock == 0 ? 1 : 0;
 
+  //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x-32)
   {
     idx = base_idx + threadIdx.x;
@@ -3155,8 +3156,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
       for(int col = 0; col < 32; col++)
         smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
     }
-    ticktock = ticktock == 0 ? 1 : 0;
+    //ticktock = ticktock == 0 ? 1 : 0;
 
+    __syncthreads();
     if(warp_id == (WARPS-1))
       for(int k = 0; k < batch_size_warps; k++)
       {
@@ -3166,11 +3168,22 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     }
   }
 
+  //__syncthreads();
+  //if(warp_id == (WARPS-1))
+  //  for(int k = 0; k < batch_size_warps; k++)
+  //  {
+  //    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+  //    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+  //    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+  //  }
+  __syncthreads();
+
   // 129 mu
   if(warp_id == (WARPS-1))
     wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
   __syncthreads();
 
   //if(threadIdx.x >= 16){ return; }
   //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);
 
@@ -2355,13 +2355,18 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
+    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
         for i in range(100):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
             #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, 128+32, dtype=dtype, device='cuda')
-            B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128)
+            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            B = torch.randn(4*496, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
 
             #print('')
             #print(A)
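Note on the B change above: dividing B by math.sqrt(dim) keeps the magnitude of torch.matmul(A, B.t()) roughly constant as dim grows, so one set of absolute tolerances works for every size in the new dim loop. A minimal standalone sketch of that effect (the shapes and the loop here are illustrative, not part of the commit):

import math
import torch

# Illustrative only: with B scaled by 1/sqrt(dim), the magnitude of the
# matmul output stays roughly constant across dim, so the same absolute
# tolerances can cover every size tested above.
for dim in [32, 256, 2048]:
    A = torch.randn(1, dim)
    B = torch.randn(128, dim) / math.sqrt(dim)
    C = A @ B.t()
    print(dim, torch.abs(C).mean().item())  # stays near sqrt(2/pi) ~ 0.8 for all dims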
@@ -2372,30 +2377,40 @@ def test_cutlass3_gemm(dtype):
 
             C1 = torch.matmul(A, B.t())
             C2 = F.cutlass3_gemm(A, B.t())
-            err = C1-C2
 
             # tensor cores are non-deterministic
             # so we need to analyze errors around the mean
             # to test our implementation
-            err = torch.abs(err.mean()).item()
-            mag = torch.abs(C1).mean()
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
             relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
 
-            if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-                print('')
-                print(i, err, mag.item(), relerr.item())
-                print(A.flatten()[-6:])
-                print(B.flatten()[-6:])
-                out = A.flatten()[-6:]*B.flatten()[-6:]
-                print(out)
-                print(out[:-1].sum())
-                print('='*80)
-                print(C1.flatten()[-6:])
-                print(C2.flatten()[-6:])
-                #assert False, 'ERROR'
+            errs.append(err)
+            relerrs.append(relerr)
+
+            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            #    print('')
+            #    print(i, err, mag.item(), relerr.item())
+            #    print(A.flatten()[-6:])
+            #    print(B.flatten()[-6:])
+            #    out = A.flatten()[-6:]*B.flatten()[-6:]
+            #    print(out)
+            #    print(out[:-1].sum())
+            #    print('='*80)
+            #    print(C1.flatten()[-6:])
+            #    print(C2.flatten()[-6:])
+            #    #assert False, 'ERROR'
 
-        c = int(C1.numel()*0.001)
+        c = int(C1.numel()*0.00125*(dim/256))+1
         assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))
 
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
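Summary of the new analysis: each iteration now records the mean absolute error and the mean relative error (with a 1e-8 guard against division by zero), the running maxima are tracked, and the per-dim averages are printed scaled by 1/math.sqrt(dim) so the numbers are comparable across sizes. A condensed sketch of the same bookkeeping as standalone helpers (scaled_error_stats and assert_count_approx_close are illustrative names; the repository's actual assert_all_approx_close may differ):

import torch

def scaled_error_stats(C1, C2, errs, relerrs):
    # Hypothetical helper mirroring the analysis in the test above: collect
    # per-iteration mean errors and return the elementwise maxima.
    err = torch.abs(C1 - C2)                  # absolute error per element
    relerr = err / (torch.abs(C1) + 1e-8)     # relative error, guarded against /0
    errs.append(err.mean().item())
    relerrs.append(relerr.mean().item())
    return err.max().item(), relerr.max().item()

def assert_count_approx_close(a, b, rtol, atol, count):
    # Assumed semantics of assert_all_approx_close: tolerate up to `count`
    # elementwise mismatches before failing, since tensor-core accumulation
    # order makes a small number of outliers expected.
    mismatches = (~torch.isclose(a, b, rtol=rtol, atol=atol)).sum().item()
    assert mismatches <= count, f"{mismatches} values not close (budget {count})"

The 1/math.sqrt(dim) factor in the printed summaries normalizes a length-dim dot product, whose accumulated rounding error is expected to grow roughly with sqrt(dim) under a random-error model; the mismatch budget c = int(C1.numel()*0.00125*(dim/256))+1 likewise lets the number of tolerated outliers grow with dim.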