Added col_ampere outlier extraction kernel.

This commit is contained in:
Tim Dettmers 2022-07-26 18:15:51 -07:00
parent bcab99ec87
commit 32fa459ed7
2 changed files with 32 additions and 22 deletions

View File

@ -2626,28 +2626,32 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
offset += tile_offset_rows + tile_offset_cols; offset += tile_offset_rows + tile_offset_cols;
char val = A[offset];
char val = 0;
//printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
if(offset > tiledColsA*tiledRowsA)
printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
else
val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x; int out_idx = (row*idx_size) + blockIdx.x;
//if(out_idx > colsA*idx_size)
if(val != 0)
{
//printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val);
out[out_idx] = val;
}
else
{
out[out_idx] = val; out[out_idx] = val;
} }
} }
else if(FORMAT == COL_AMPERE)
{
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
{
// we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
// within each tile.
int offset_per_col_tile = ((rowsA+31)/32)*32*32;
int tile_offset_rows = (row/32)*32*32;
int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
int subtile_col_idx = local_colidx%32;
int subtile_row_idx = row % 32;
// this magic is taken from the cublasLt doc (search for COL32)
int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
offset += tile_offset_cols + tile_offset_rows;
char val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x;
out[out_idx] = val;
}
} }
} }

View File

@ -1859,9 +1859,9 @@ def test_zp():
def test_extract_outliers(): def test_extract_outliers():
for i in range(k): for i in range(k):
shapeA = (4096, 4*4096) shapeA = (4096, 4096*4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
#idx = torch.Tensor([32]).int().cuda() #idx = torch.Tensor([0]).int().cuda()
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
outliers1 = A[:, idx.long()] outliers1 = A[:, idx.long()]
@ -1872,7 +1872,13 @@ def test_extract_outliers():
assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel() assert outliers2.shape[1] == idx.numel()
#print(outliers1) torch.testing.assert_allclose(outliers1, outliers2)
#print(outliers2)
CA, SA = F.transform(A, 'col_ampere')
outliers2 = F.extract_outliers(CA, SA, idx)
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_allclose(outliers1, outliers2) torch.testing.assert_allclose(outliers1, outliers2)