Added pre and post device call to transform.
This commit is contained in:
parent
320eacb4c2
commit
6101a8fb9f
|
@ -1214,6 +1214,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
|
|||
ptrA = get_ptr(A)
|
||||
ptrOut = get_ptr(out)
|
||||
is_on_gpu([A, out])
|
||||
prev_device = pre_call(A.device)
|
||||
if to_order == 'col32':
|
||||
if transpose:
|
||||
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
|
||||
|
@ -1236,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
|
|||
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
|
||||
else:
|
||||
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
|
||||
|
||||
|
||||
post_call(prev_device)
|
||||
|
||||
|
||||
return out, new_state
|
||||
|
|
Loading…
Reference in New Issue
Block a user