Added col_ampere outlier extraction kernel.
This commit is contained in:
parent
bcab99ec87
commit
32fa459ed7
|
@ -2626,28 +2626,32 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
|
|||
|
||||
offset += tile_offset_rows + tile_offset_cols;
|
||||
|
||||
|
||||
char val = 0;
|
||||
//printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
|
||||
if(offset > tiledColsA*tiledRowsA)
|
||||
printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
|
||||
else
|
||||
val = A[offset];
|
||||
char val = A[offset];
|
||||
|
||||
int out_idx = (row*idx_size) + blockIdx.x;
|
||||
|
||||
//if(out_idx > colsA*idx_size)
|
||||
if(val != 0)
|
||||
{
|
||||
//printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val);
|
||||
out[out_idx] = val;
|
||||
}
|
||||
else
|
||||
{
|
||||
out[out_idx] = val;
|
||||
}
|
||||
}
|
||||
else if(FORMAT == COL_AMPERE)
|
||||
{
|
||||
|
||||
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
|
||||
{
|
||||
// we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
|
||||
// within each tile.
|
||||
int offset_per_col_tile = ((rowsA+31)/32)*32*32;
|
||||
int tile_offset_rows = (row/32)*32*32;
|
||||
int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
|
||||
int subtile_col_idx = local_colidx%32;
|
||||
int subtile_row_idx = row % 32;
|
||||
// this magic is taken from the cublasLt doc (search for COL32)
|
||||
int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
|
||||
offset += tile_offset_cols + tile_offset_rows;
|
||||
|
||||
char val = A[offset];
|
||||
int out_idx = (row*idx_size) + blockIdx.x;
|
||||
out[out_idx] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1859,9 +1859,9 @@ def test_zp():
|
|||
|
||||
def test_extract_outliers():
|
||||
for i in range(k):
|
||||
shapeA = (4096, 4*4096)
|
||||
shapeA = (4096, 4096*4)
|
||||
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
|
||||
#idx = torch.Tensor([32]).int().cuda()
|
||||
#idx = torch.Tensor([0]).int().cuda()
|
||||
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
|
||||
outliers1 = A[:, idx.long()]
|
||||
|
||||
|
@ -1872,7 +1872,13 @@ def test_extract_outliers():
|
|||
assert outliers2.shape[0] == shapeA[0]
|
||||
assert outliers2.shape[1] == idx.numel()
|
||||
|
||||
#print(outliers1)
|
||||
#print(outliers2)
|
||||
torch.testing.assert_allclose(outliers1, outliers2)
|
||||
|
||||
CA, SA = F.transform(A, 'col_ampere')
|
||||
|
||||
outliers2 = F.extract_outliers(CA, SA, idx)
|
||||
|
||||
assert outliers2.shape[0] == shapeA[0]
|
||||
assert outliers2.shape[1] == idx.numel()
|
||||
|
||||
torch.testing.assert_allclose(outliers1, outliers2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user