diff --git a/tests/test_functional.py b/tests/test_functional.py index 0500984..808c1ce 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2406,6 +2406,7 @@ def test_cutlass3_gemm(dtype): # #assert False, 'ERROR' c = int(C1.numel()*0.00125*(dim/256))+1 + assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim))