forked from mrq/bitsandbytes-rocm
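Added pre_call/post_call device switching around the kernel launch in extract_outliers, and moved the pre_call in transform to the start of the function.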
This commit is contained in:
parent
cc5b323876
commit
ab72a1294f
|
@ -1198,6 +1198,7 @@ def get_special_format_str():
|
||||||
|
|
||||||
|
|
||||||
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
|
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
|
||||||
|
prev_device = pre_call(A.device)
|
||||||
if state is None: state = (A.shape, from_order)
|
if state is None: state = (A.shape, from_order)
|
||||||
else: from_order = state[1]
|
else: from_order = state[1]
|
||||||
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
|
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
|
||||||
|
@ -1214,7 +1215,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
|
||||||
ptrA = get_ptr(A)
|
ptrA = get_ptr(A)
|
||||||
ptrOut = get_ptr(out)
|
ptrOut = get_ptr(out)
|
||||||
is_on_gpu([A, out])
|
is_on_gpu([A, out])
|
||||||
prev_device = pre_call(A.device)
|
|
||||||
if to_order == 'col32':
|
if to_order == 'col32':
|
||||||
if transpose:
|
if transpose:
|
||||||
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
|
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
|
||||||
|
@ -1237,8 +1237,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
|
||||||
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
|
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
|
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
|
||||||
post_call(prev_device)
|
|
||||||
|
|
||||||
|
post_call(prev_device)
|
||||||
|
|
||||||
return out, new_state
|
return out, new_state
|
||||||
|
|
||||||
|
@ -1451,10 +1451,12 @@ def extract_outliers(A, SA, idx):
|
||||||
ptrIdx = get_ptr(idx)
|
ptrIdx = get_ptr(idx)
|
||||||
ptrOut = get_ptr(out)
|
ptrOut = get_ptr(out)
|
||||||
|
|
||||||
|
prev_device = pre_call(A.device)
|
||||||
if formatA == 'col_turing':
|
if formatA == 'col_turing':
|
||||||
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
|
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
|
||||||
elif formatA == 'col_ampere':
|
elif formatA == 'col_ampere':
|
||||||
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
|
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
|
||||||
|
post_call(prev_device)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user