forked from mrq/bitsandbytes-rocm
Boilerplate and test for extract_outliers.
This commit is contained in:
parent
c771b3a75a
commit
cbb901ac51
|
@ -1409,3 +1409,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
|
|||
x *= SA[1]/127
|
||||
x +=offset
|
||||
return x.to(dtype)
|
||||
|
||||
def extract_outliers(A, SA, idx):
|
||||
shapeA = SA[0]
|
||||
formatA = SA[1]
|
||||
assert formatA in ['col_turing', 'col_ampere']
|
||||
assert A.device.type == 'cuda'
|
||||
|
||||
out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
|
||||
|
||||
idx_size = ct.c_int32(idx.numel())
|
||||
rows = ct.c_int32(shapeA[0])
|
||||
cols = ct.c_int32(shapeA[1])
|
||||
ptrA = get_ptr(A)
|
||||
ptrIdx = get_ptr(idx)
|
||||
ptrOut = get_ptr(out)
|
||||
|
||||
if formatA == 'col_turing':
|
||||
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
|
||||
elif formatA == 'col_ampere':
|
||||
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -2592,10 +2592,17 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
|
|||
}
|
||||
}
|
||||
|
||||
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
|
||||
{
|
||||
}
|
||||
|
||||
//==============================================================
|
||||
// TEMPLATE DEFINITIONS
|
||||
//==============================================================
|
||||
|
||||
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||
|
||||
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
|
|
|
@ -118,6 +118,8 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
|
|||
|
||||
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
|
||||
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
|
27
csrc/ops.cu
27
csrc/ops.cu
|
@ -578,10 +578,37 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
|
|||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
|
||||
{
|
||||
int threads = 256;
|
||||
// we load 128 column values per warp
|
||||
int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
|
||||
int tiledRows = 0;
|
||||
|
||||
int elements = idx_size*cols; // matrix A is transposed, so we extract columns
|
||||
int num_blocks = (elements+threads-1)/threads;
|
||||
|
||||
if(FORMAT == COL_TURING)
|
||||
{
|
||||
tiledRows = fill_up_to_nearest_multiple(rows, 8);
|
||||
}
|
||||
else if(FORMAT == COL_AMPERE)
|
||||
{
|
||||
tiledRows = fill_up_to_nearest_multiple(rows, 32);
|
||||
}
|
||||
|
||||
kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, rows, cols, tiledRows, tiledCols);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
//==============================================================
|
||||
// TEMPLATE DEFINITIONS
|
||||
//==============================================================
|
||||
|
||||
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||
|
||||
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
|
||||
|
||||
|
|
|
@ -174,4 +174,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
|
|||
|
||||
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
|
||||
|
||||
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||
|
||||
#endif
|
||||
|
|
|
@ -106,6 +106,9 @@ void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRo
|
|||
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
|
||||
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
|
||||
|
||||
void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
|
||||
void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
|
||||
|
||||
int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
|
@ -280,6 +283,9 @@ extern "C"
|
|||
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
|
||||
|
||||
void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
|
||||
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
|
||||
|
||||
#endif
|
||||
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
|
||||
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
|
||||
|
|
|
@ -1856,3 +1856,20 @@ def test_zp():
|
|||
print(err1, err2, err3, err4, err5, err6)
|
||||
|
||||
|
||||
|
||||
def test_extract_outliers():
|
||||
shapeA = (128, 128)
|
||||
idx = torch.randint(0, shapeA[1], size=(10,)).int()
|
||||
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
|
||||
outliers1 = A[:, idx.long()]
|
||||
|
||||
CA, SA = F.transform(A, 'col_turing')
|
||||
|
||||
outliers2 = F.extract_outliers(CA, SA, idx)
|
||||
|
||||
assert outliers2.shape[0] == shapeA[0]
|
||||
assert outliers2.shape[1] == idx.numel()
|
||||
|
||||
|
||||
|
||||
torch.testing.assert_allclose(outliers1, outliers2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user