Removed the row_scale option from igemmlt (it segfaults on Ampere GPUs).
This commit is contained in:
parent
8b1fd32e3e
commit
1e88edd8c0
1
Makefile
1
Makefile
|
@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
|
||||||
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
|
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
|
||||||
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
|
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
|
||||||
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
|
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
|
||||||
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
|
|
||||||
|
|
||||||
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
|
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
|
||||||
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
|
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
|
||||||
|
|
|
@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
|
||||||
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
|
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
|
def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
||||||
shapeA = SA[0]
|
shapeA = SA[0]
|
||||||
shapeB = SB[0]
|
shapeB = SB[0]
|
||||||
dimsA = len(shapeA)
|
dimsA = len(shapeA)
|
||||||
|
@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
|
||||||
elif dimsA == 3 and out is None:
|
elif dimsA == 3 and out is None:
|
||||||
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
|
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
|
||||||
|
|
||||||
if row_scale is not None: assert row_scale.numel() == out.shape[0]
|
|
||||||
assert dimsB != 3, 'len(B.shape)==3 not supported'
|
assert dimsB != 3, 'len(B.shape)==3 not supported'
|
||||||
assert A.device.type == 'cuda'
|
assert A.device.type == 'cuda'
|
||||||
assert B.device.type == 'cuda'
|
assert B.device.type == 'cuda'
|
||||||
|
@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
|
||||||
ptrA = get_ptr(A)
|
ptrA = get_ptr(A)
|
||||||
ptrB = get_ptr(B)
|
ptrB = get_ptr(B)
|
||||||
ptrC = get_ptr(out)
|
ptrC = get_ptr(out)
|
||||||
ptrRowScale = get_ptr(row_scale)
|
|
||||||
|
|
||||||
k = shapeA[-1]
|
k = shapeA[-1]
|
||||||
lda = ct.c_int32(m*32)
|
lda = ct.c_int32(m*32)
|
||||||
|
@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
|
||||||
k = ct.c_int32(k)
|
k = ct.c_int32(k)
|
||||||
|
|
||||||
has_error = 0
|
has_error = 0
|
||||||
|
ptrRowScale = get_ptr(None)
|
||||||
if formatB == 'col_turing':
|
if formatB == 'col_turing':
|
||||||
if dtype == torch.int32:
|
if dtype == torch.int32:
|
||||||
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
||||||
elif row_scale is None:
|
|
||||||
has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
|
||||||
else:
|
else:
|
||||||
has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
||||||
elif formatB == 'col_ampere':
|
elif formatB == 'col_ampere':
|
||||||
if dtype == torch.int32:
|
if dtype == torch.int32:
|
||||||
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
||||||
elif row_scale is None:
|
|
||||||
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
|
||||||
else:
|
else:
|
||||||
has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
||||||
|
|
||||||
if has_error == 1:
|
if has_error == 1:
|
||||||
raise Exception('cublasLt ran into an error!')
|
raise Exception('cublasLt ran into an error!')
|
||||||
|
|
|
@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist()
|
||||||
values = list(zip(dim1, dim4, inner))
|
values = list(zip(dim1, dim4, inner))
|
||||||
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
|
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
|
||||||
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
|
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
|
||||||
|
@pytest.mark.skip("Row scale has some bugs for ampere")
|
||||||
def test_igemmlt_row_scale(dim1, dim4, inner):
|
def test_igemmlt_row_scale(dim1, dim4, inner):
|
||||||
formatB = F.get_special_format_str()
|
formatB = F.get_special_format_str()
|
||||||
err1, err2, err3 = [], [], []
|
err1, err2, err3 = [], [], []
|
||||||
|
@ -1064,6 +1065,7 @@ dim4 = [12288, 4096]
|
||||||
values = list(zip(dim1, dim4, inner))
|
values = list(zip(dim1, dim4, inner))
|
||||||
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
|
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
|
||||||
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
|
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
|
||||||
|
@pytest.mark.skip("Row scale has some bugs for ampere")
|
||||||
def test_row_scale_bench(dim1, dim4, inner):
|
def test_row_scale_bench(dim1, dim4, inner):
|
||||||
err1, err2, err3 = [], [], []
|
err1, err2, err3 = [], [], []
|
||||||
relerr1, relerr2 = [], []
|
relerr1, relerr2 = [], []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user