Added fp32 compute type for gemv_4bit.
parent cef519c89e
commit 5fab673442
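With this change, gemv_4bit accepts float32 inputs and dispatches them to the new cgemm_4bit_inference_naive_fp32 entry point. A minimal usage sketch (not part of the diff; the quantize_4bit/gemv_4bit signatures are assumed from this version of bitsandbytes):

import torch
import bitsandbytes.functional as F

dim = 2048
A = torch.randn(1, 1, dim, dtype=torch.float32, device="cuda")              # gemv_4bit expects a vector with leading dims of 1
B = torch.randn(4 * dim, dim, dtype=torch.float32, device="cuda") / dim**0.5

qB, state = F.quantize_4bit(B, quant_type="nf4")                            # state from quantize_4bit is required
C = F.gemv_4bit(A, qB.t(), state=state)                                     # fp32 input now routes to the fp32 kernel
print(C.shape, C.dtype)                                                     # torch.Size([1, 1, 8192]) torch.float32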
@@ -1464,6 +1464,9 @@ def gemv_4bit(
    if state is None:
        raise ValueError(f'state cannot be None. gemv_4bit( ) requires the state from quantize_4bit( )')

    if A.numel() != A.shape[-1]:
        raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

    Bshape = state[1]
    bout = Bshape[0]
    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
@@ -1474,90 +1477,17 @@ def gemv_4bit(

    if out is None:
        if len(A.shape) == 3:
-            out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
+            out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
        else:
-            out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
+            out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)

-    sA = A.shape
-    sB = B.shape
-    if transposed_A and len(sA) == 2:
-        sA = (sA[1], sA[0])
-    elif transposed_A and len(sA) == 3:
-        sA = (sA[0], sA[2], sA[0])
-    if transposed_B and len(sB) == 2:
-        sB = (sB[1], sB[0])
-    elif transposed_B and len(sB) == 3:
-        sB = (sB[0], sB[2], sB[0])
-    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
-    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
-    # (transpose of row major is column major)
-    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
-    # matrices in the input arguments for cuBLAS
-    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
-    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
-    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
-    if len(sB) == 2:
-        if B.stride()[0] == B.shape[1]:
-            transposed_B = False
-        elif B.stride()[1] == B.shape[0]:
-            transposed_B = True
-        if len(A.shape) == 2:
-            if A.stride()[0] == A.shape[1]:
-                transposed_A = False
-            elif A.stride()[1] == A.shape[0]:
-                transposed_A = True
-        else:
-            if A.stride()[1] == A.shape[2]:
-                transposed_A = False
-            elif A.stride()[2] == A.shape[1]:
-                transposed_A = True
-
-        if len(sA) == 2:
-            n = sA[0]
-            ldb = A.stride()[1 if transposed_A else 0]
-        elif len(sA) == 3 and len(sB) == 2:
-            n = sA[0] * sA[1]
-            ldb = sA[2]
-
-        m = sB[1]
-        k = sB[0]
-        lda = B.stride()[0]
-        ldc = sB[1]
-    elif len(sB) == 3:
-        # special case
-        assert len(sA) == 3
-        if not (sA[0] == sB[0] and sA[1] == sB[1]):
-            raise ValueError(
-                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
-            )
-
-        transposed_A = True
-        transposed_B = False
-
-        m = sB[2]
-        n = sA[2]
-        k = sB[0] * sB[1]
-
-        lda = n
-        ldb = sA[2]
-        ldc = m
-
-    # B^T @ A^T = C^T
-    # [km, nk -> mn]
-    #lda = ldb = ldc = 1
-    #lda = 1
-    if state is not None:
-        n = 1
-        m = Bshape[0]
-        k = Bshape[1]
-        lda = Bshape[0]
-        ldc = Bshape[0]
-        ldb = (ldb+1)//2
-    #print(m, n, k, lda, ldb, ldc)
-    is_on_gpu([B, A, out])
+    n = 1
+    m = Bshape[0]
+    k = Bshape[1]
+    lda = Bshape[0]
+    ldc = Bshape[0]
+    ldb = (A.shape[-1]+1)//2
+    is_on_gpu([B, A, out, absmax, state[-1]])
    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)
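The comment block removed above rests on the identity that a row-major matrix handed to a column-major GEMM is read as its transpose, so computing B^T @ A^T yields C^T, which reads back as C = A @ B in row-major order. A small illustrative check of that identity (not code from the repository):

import torch

m, k, n = 3, 4, 5
A = torch.randn(m, k)
B = torch.randn(k, n)

C = A @ B                    # the row-major result we want
C_t = (B.T @ A.T).T          # what a column-major GEMM effectively computes, read back as row-major
assert torch.allclose(C, C_t, atol=1e-6)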
@@ -1570,6 +1500,8 @@ def gemv_4bit(
            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
        elif A.dtype == torch.bfloat16:
            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+        elif A.dtype == torch.float32:
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
        else:
            raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
    else:
@@ -3520,7 +3520,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
}

#define num_values_4bit 32
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
{

  // per threadblock:
@@ -3528,7 +3528,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
  // 4 warps -> 4 loads per iter
  // 1x128 * 128x4 -> 1x4 outputs
-  //typedef cub::WarpReduce<T> WarpReduce;
  typedef cub::WarpReduce<float> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];

@@ -3536,7 +3535,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
  const int warp_lane = threadIdx.x % 32;
  const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
  const int num_values_8bit = num_values_4bit/2;
-  //T local_C = T(0.0f);
  float local_C = 0.0f;

  unsigned char local_B_4bit[num_values_8bit];
@@ -3585,10 +3583,24 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in

    if(inner_idx+num_values_4bit)
    {
-      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 0];
-      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
-      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
-      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
+      if(BITS==16)
+      {
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 1];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 2];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 3];
+      }
+      else
+      {
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 0];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 1];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 2];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 3];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[4] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 4];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[5] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 5];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[6] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 6];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[7] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 7];
+      }

    }
    else
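The new BITS template parameter (16 for half/__nv_bfloat16, 32 for float) determines how many 16-byte int4 vector loads are needed to pull num_values_4bit = 32 elements of A into local_A, and which divisor applies to inner_idx. A quick arithmetic check of the two branches above (illustrative only, not code from the repository):

num_values_4bit = 32
int4_bytes = 16                                            # one CUDA int4 is 4 x 4 bytes
for bits in (16, 32):
    elem_bytes = bits // 8
    n_loads = num_values_4bit * elem_bytes // int4_bytes   # int4 loads needed to fill local_A
    elems_per_int4 = int4_bytes // elem_bytes              # divisor applied to inner_idx
    print(bits, n_loads, elems_per_int4)
# 16-bit: 4 loads, inner_idx/(num_values_4bit/4) == inner_idx/8
# 32-bit: 8 loads, inner_idx/(num_values_4bit/8) == inner_idx/4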
@@ -3776,8 +3788,9 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

-template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);

template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
@@ -125,7 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *

template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
csrc/ops.cu
@@ -729,12 +729,12 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
  //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
{

  int num_blocks = (m+3)/4;

-  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
}

template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
@@ -757,8 +757,10 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);

template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);

//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
@@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows

template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
@@ -29,10 +29,13 @@ void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, floa
  { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
-  { gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+  { gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-  { gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+  { gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

+void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
+  { gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
@@ -400,6 +403,9 @@ extern "C"
  void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
  { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

+  void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
+  { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
+
  #endif
  void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
  void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
    print("partial matmul", time.time() - t0)


-batch_size = 5
+batch_size = 1
seqdim = 1
values = []
#values.append((batch_size, seqdim, 768, 4 * 768))
@@ -1793,7 +1793,7 @@ values.append((batch_size, seqdim, 6656, 4*6656))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
-    iters = 80
+    iters = 1000
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
@@ -2361,9 +2361,7 @@ def test_normal_map_tree():

@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
-#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant):
    print('')
    for dim in [128, 256, 512, 1024, 2048, 4096]: