forked from mrq/bitsandbytes-rocm
Removed rowscale (segfaults on Ampere).
This commit is contained in:
parent 8b1fd32e3e
commit 1e88edd8c0
Makefile

@@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
 COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
 COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
 COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta

 # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
 CC_CUDA92 := -gencode arch=compute_30,code=sm_30
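Each -gencode flag above embeds code for one GPU generation, and the hunk drops one of the two identical sm_72 entries. For reference, a quick way to see which compute capability your own device would need from this list (a small sketch assuming torch is installed, as elsewhere in this repo):

import torch

# e.g. (8, 6) on an RTX 30xx card ("Ampere"), (7, 5) on Turing, (7, 0) on Volta
major, minor = torch.cuda.get_device_capability()
print(f"needs: -gencode arch=compute_{major}{minor},code=sm_{major}{minor}")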
@@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
         ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
     return out

-def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
+def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     shapeA = SA[0]
     shapeB = SB[0]
     dimsA = len(shapeA)
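For orientation, a minimal sketch of how the trimmed-down igemmlt is typically driven. The import path, tensor shapes, and the transform/get_special_format_str helpers are assumed from the surrounding library, not taken from this diff:

import torch
import bitsandbytes.functional as F

# int8 inputs: A is (rows, inner), B is (out_features, inner)
A = torch.randint(-128, 127, size=(32, 4096), dtype=torch.int8, device='cuda')
B = torch.randint(-128, 127, size=(1024, 4096), dtype=torch.int8, device='cuda')

formatB = F.get_special_format_str()       # 'col_turing' or 'col_ampere'
CA, SA = F.transform(A, to_order='col32')  # tiled layout expected by cublasLt
CB, SB = F.transform(B, to_order=formatB)

# row_scale is gone from the signature; int32 accumulation (the default dtype) remains
out32, Sout32 = F.igemmlt(CA, CB, SA, SB)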
@@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
     elif dimsA == 3 and out is None:
         out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')

-    if row_scale is not None: assert row_scale.numel() == out.shape[0]
     assert dimsB != 3, 'len(B.shape)==3 not supported'
     assert A.device.type == 'cuda'
     assert B.device.type == 'cuda'
@@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
     ptrA = get_ptr(A)
     ptrB = get_ptr(B)
     ptrC = get_ptr(out)
-    ptrRowScale = get_ptr(row_scale)

     k = shapeA[-1]
     lda = ct.c_int32(m*32)
@@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
     k = ct.c_int32(k)

     has_error = 0
+    ptrRowScale = get_ptr(None)
     if formatB == 'col_turing':
         if dtype == torch.int32:
             has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
-        elif row_scale is None:
-            has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
         else:
-            has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+            has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
     elif formatB == 'col_ampere':
         if dtype == torch.int32:
             has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
-        elif row_scale is None:
-            has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
         else:
-            has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+            has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)

     if has_error == 1:
         raise Exception('cublasLt ran into an error!')
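The C entry points keep their row-scale parameter, so the Python side now always forwards a null pointer via get_ptr(None), which the non-rowscale kernels are expected to ignore. A sketch of the get_ptr-style helper this relies on (the exact body may differ slightly in this fork); ctypes passes Python None through as a C NULL:

import ctypes as ct

def get_ptr(A):
    # None becomes a NULL pointer argument on the C side,
    # so the *_8 kernels receive rowscale == NULL
    if A is None:
        return None
    return ct.c_void_p(A.data.data_ptr())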
@@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist()
 values = list(zip(dim1, dim4, inner))
 names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
 def test_igemmlt_row_scale(dim1, dim4, inner):
     formatB = F.get_special_format_str()
     err1, err2, err3 = [], [], []
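The row-scale tests are disabled rather than deleted. A minimal sketch of the marker pattern (the test names below are illustrative, not from this repository): pytest.mark.skip takes the reason as its first argument and skips unconditionally, while skipif would allow keeping the tests alive on hardware where the kernels still work:

import pytest
import torch

@pytest.mark.skip("Row scale has some bugs for ampere")
def test_always_skipped():
    assert False  # never executed; reported as skipped with the reason above

# hypothetical alternative: only skip on Ampere-class GPUs (compute capability >= 8.0)
on_ampere = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

@pytest.mark.skipif(on_ampere, reason="row-scale kernels segfault on Ampere")
def test_skipped_only_on_ampere():
    pass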
@@ -1064,6 +1065,7 @@ dim4 = [12288, 4096]
 values = list(zip(dim1, dim4, inner))
 names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
 def test_row_scale_bench(dim1, dim4, inner):
     err1, err2, err3 = [], [], []
     relerr1, relerr2 = [], []