Removed storage() from get_ptr; added boilerplate for a bias argument to mm_dequant (dequant_mm_int32_fp16).
This commit is contained in:
parent
26efb154c8
commit
1ed2fa2f21
|
@ -218,7 +218,7 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
|
|||
if A is None:
|
||||
return None
|
||||
else:
|
||||
return ct.c_void_p(A.data.storage().data_ptr())
|
||||
return ct.c_void_p(A.data.data_ptr())
|
||||
|
||||
|
||||
def pre_call(device):
|
||||
|
@ -1407,8 +1407,10 @@ def mm_dequant(
|
|||
out=None,
|
||||
new_row_stats=None,
|
||||
new_col_stats=None,
|
||||
bias=None
|
||||
):
|
||||
assert A.dtype == torch.int32
|
||||
if bias is not None: assert bias.dtype == torch.float16
|
||||
out_shape = quant_state[0]
|
||||
if len(out_shape) == 3:
|
||||
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
|
||||
|
@ -1430,17 +1432,20 @@ def mm_dequant(
|
|||
new_col_stats.shape[0] == col_stats.shape[0]
|
||||
), f"{new_col_stats.shape} vs {col_stats.shape}"
|
||||
|
||||
prev_device = pre_call(A.device)
|
||||
ptrA = get_ptr(A)
|
||||
ptrOut = get_ptr(out)
|
||||
ptrRowStats = get_ptr(row_stats)
|
||||
ptrColStats = get_ptr(col_stats)
|
||||
ptrNewRowStats = get_ptr(new_row_stats)
|
||||
ptrNewColStats = get_ptr(new_col_stats)
|
||||
ptrBias = get_ptr(bias)
|
||||
numRows = ct.c_int32(out_shape[0])
|
||||
numCols = ct.c_int32(out_shape[1])
|
||||
|
||||
is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
|
||||
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
|
||||
is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
|
||||
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols)
|
||||
post_call(prev_device)
|
||||
|
||||
return out
|
||||
|
||||
|
|
|
@ -1889,7 +1889,7 @@ template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __rest
|
|||
|
||||
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
|
||||
|
||||
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n)
|
||||
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n)
|
||||
{
|
||||
|
||||
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
|
||||
|
@ -2675,7 +2675,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(
|
|||
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
|
||||
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
|
||||
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
|
||||
|
||||
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
|
||||
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
|
||||
|
|
|
@ -111,7 +111,7 @@ template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_s
|
|||
|
||||
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
|
||||
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
|
||||
half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
|
||||
half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
|
||||
|
||||
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
|
||||
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
|
||||
|
|
|
@ -435,7 +435,7 @@ int fill_up_to_nearest_multiple(int value, int multiple)
|
|||
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
|
||||
}
|
||||
|
||||
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
|
||||
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
|
||||
{
|
||||
int threads = 512;
|
||||
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
|
||||
|
@ -447,7 +447,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
|
|||
num_blocks = num_blocks*(tileCols/32);
|
||||
assert(threads <= tilesize);
|
||||
|
||||
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
|
||||
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -465,7 +465,6 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
|
|||
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||
int num_blocks = row_tiles * col_tiles;
|
||||
|
||||
|
||||
if(nnz_threshold == 0.0)
|
||||
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
||||
else if(nnz_threshold != 0.0)
|
||||
|
|
|
@ -163,7 +163,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
|
|||
|
||||
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
|
||||
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
|
||||
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols);
|
||||
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols);
|
||||
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
|
||||
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
|
||||
int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
|
||||
|
|
|
@ -248,8 +248,8 @@ extern "C"
|
|||
MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
|
||||
MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
|
||||
|
||||
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
|
||||
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); }
|
||||
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols)
|
||||
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); }
|
||||
void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
|
||||
{ getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
|
||||
|
||||
|
|
|
@ -961,20 +961,24 @@ dim4 = torch.randint(64, 1024, size=(n,)).tolist()
|
|||
dims = (2,)
|
||||
# ldb = list(range(256, 1*1024, 256))
|
||||
formatB = ["col_turing", "col_ampere"]
|
||||
values = list(product(dim1, dim4, dims, formatB))
|
||||
has_bias = [True, False]
|
||||
values = list(product(dim1, dim4, dims, formatB, has_bias))
|
||||
names = [
|
||||
"dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values
|
||||
"dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
|
||||
def test_dequant_mm(dim1, dim4, dims, formatB):
|
||||
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
|
||||
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
|
||||
inner = torch.randint(1, 128, size=(1,)).item()
|
||||
bias = None
|
||||
if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
|
||||
formatB = F.get_special_format_str()
|
||||
for i in range(k):
|
||||
A = torch.randn(dim1, inner, device="cuda")
|
||||
B = torch.randn(dim4, inner, device="cuda")
|
||||
C1 = torch.matmul(A.half(), B.t().half())
|
||||
if has_bias: C1 += bias
|
||||
|
||||
A1, maxA = F.vectorwise_quant(A, dim=1)
|
||||
B1, maxB = F.vectorwise_quant(B, dim=1)
|
||||
|
@ -985,17 +989,15 @@ def test_dequant_mm(dim1, dim4, dims, formatB):
|
|||
|
||||
C3, S = F.nvidia_transform(C2, "row", state=SC)
|
||||
C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
|
||||
if has_bias: C4 += bias
|
||||
|
||||
count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
|
||||
n = C1.numel()
|
||||
p = 0.06
|
||||
assert (
|
||||
count / n < p
|
||||
), f"error in more than {p} of elements: {count}/{n}={count/n}"
|
||||
assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
|
||||
|
||||
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
|
||||
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
|
||||
torch.testing.assert_allclose(C5, C4)
|
||||
# print(C2)
|
||||
|
||||
|
||||
n = 2
|
||||
|
|
Loading…
Reference in New Issue
Block a user