4-bit draft.

Tim Dettmers 2023-04-29 21:52:47 -07:00
parent cad839941b
commit 21723f796a
7 changed files with 273 additions and 27 deletions


@@ -1380,10 +1380,15 @@ def cutlass3_gemm(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
+    state=None
 ):
-    sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
+    #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
+    if state is None:
+        Bshape = B.shape
+    else:
+        Bshape = state[1]
     if out is None:
-        out = torch.zeros(size=sout, dtype=A.dtype, device=A.device)
+        out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device)
 
     sA = A.shape
     sB = B.shape
@@ -1456,7 +1461,13 @@ def cutlass3_gemm(
     # [km, nk -> mn]
     #lda = ldb = ldc = 1
     #lda = 1
-    #print(m, n, k, lda, ldb, ldc)
+    if state is not None:
+        m = Bshape[0]
+        k = Bshape[1]
+        lda = Bshape[1]
+        ldc = Bshape[0]
+        ldb = (ldb+1)//2
+    print(m, n, k, lda, ldb, ldc)
     is_on_gpu([B, A, out])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
@@ -1464,7 +1475,10 @@ def cutlass3_gemm(
     lda = ct.c_int32(lda)
     ldb = ct.c_int32(ldb)
     ldc = ct.c_int32(ldc)
-    if A.dtype == torch.float32:
+
+    if B.dtype == torch.uint8:
+        lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+    elif A.dtype == torch.float32:
         lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     elif A.dtype == torch.float16:
         lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
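The new state path routes packed 4-bit weights to cgemm_4bit_inference, reading the original weight shape from state[1], the absmax tensor from state[0], and the blocksize from state[3]. A minimal usage sketch of this path, mirroring the test_gemm_4bit test added at the bottom of this commit:

    import torch
    import bitsandbytes.functional as F

    A = torch.rand(1, 4096, dtype=torch.float16, device='cuda')
    B = torch.rand(4 * 4096, 4096, dtype=torch.float16, device='cuda')

    # quantize_nf4 packs two 4-bit codes per uint8 and returns the
    # quantization state (absmax per block, original shape, blocksize, ...)
    qB, state = F.quantize_nf4(B)

    # with state passed, cutlass3_gemm dispatches to lib.cgemm_4bit_inference
    C = F.cutlass3_gemm(A, qB.t(), state=state)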


@@ -69,6 +69,27 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
   }
 }
 
+__device__ float d2DequantizeFP4(unsigned char val)
+{
+  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
+
+  if((val & 0b0110) == 0)
+  {
+    // subnormal
+    if((val & 0b0001) == 0)
+      return 0.0f;
+    else
+      return sign*0.0625f;
+  }
+  else
+  {
+    // normal
+    float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f);
+    float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f;
+
+    return sign*exponent*fraction;
+  }
+}
+
 __device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
 {
   float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
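d2DequantizeFP4 above decodes a sign bit, a 2-bit exponent field, and a 1-bit fraction. A direct Python transliteration of its branch logic, with one worked value (this mirrors the draft code in this commit, not a normative FP4 definition):

    def d2_dequantize_fp4(val: int) -> float:
        """Python mirror of the CUDA d2DequantizeFP4 branch logic above."""
        sign = -1.0 if (val & 0b1000) == 8 else 1.0
        if (val & 0b0110) == 0:                         # subnormal
            return 0.0 if (val & 0b0001) == 0 else sign * 0.0625
        exponent = (2.0 if (val & 0b0100) == 4 else 8.0) \
                 + (0.0 if (val & 0b0010) == 2 else 2.0)
        fraction = 1.5 if (val & 0b0001) == 1 else 1.0  # normal
        return sign * exponent * fraction

    # worked example: 0b0101 -> +1 * (2.0 + 2.0) * 1.5 = 6.0
    assert d2_dequantize_fp4(0b0101) == 6.0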
@@ -145,7 +166,7 @@ __device__ unsigned char dQuantizeFP4(float x)
       return 0b0000+sign;
 }
 
-__device__ float dDequantizeNF4(unsigned char val, float absmax)
+__device__ half dhDequantizeNF4(unsigned char val)
 {
   // the values for this tree was generated by test_normal_map_tree
   // in the file tests/test_functional.py
@@ -153,49 +174,103 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax)
     if((val & 0b0100) == 4) // 1
       if((val & 0b0010) == 2) // 11
         if((val & 0b0001) == 1) // 111
-          return 1.0f*absmax;
+          return 1.0f;
         else
-          return 0.7229568362236023f*absmax;
+          return 0.7229568362236023f;
       else
         if((val & 0b0001) == 1) // 110
-          return 0.5626170039176941f*absmax;
+          return 0.5626170039176941f;
         else
-          return 0.44070982933044434f*absmax;
+          return 0.44070982933044434f;
     else
       if((val & 0b0010) == 2) //10
         if((val & 0b0001) == 1) // 101
-          return 0.33791524171829224f*absmax;
+          return 0.33791524171829224f;
        else
-          return 0.24611230194568634f*absmax;
+          return 0.24611230194568634f;
      else
        if((val & 0b0001) == 1) // 100
-          return 0.16093020141124725f*absmax;
+          return 0.16093020141124725f;
        else
-          return 0.07958029955625534f*absmax;
+          return 0.07958029955625534f;
   else
     if((val & 0b0100) == 4) // 0
       if((val & 0b0010) == 2) //01
         if((val & 0b0001) == 1) // 011
-          return 0.0f*absmax;
+          return 0.0f;
        else
-          return -0.09105003625154495f*absmax;
+          return -0.09105003625154495f;
      else
        if((val & 0b0001) == 1) // 010
-          return -0.18477343022823334f*absmax;
+          return -0.18477343022823334f;
        else
-          return -0.28444138169288635f*absmax;
+          return -0.28444138169288635f;
    else
      if((val & 0b0010) == 2) //00
        if((val & 0b0001) == 1) // 001
-          return -0.39491748809814453f*absmax;
+          return -0.39491748809814453f;
        else
-          return -0.5250730514526367f*absmax;
+          return -0.5250730514526367f;
      else
        if((val & 0b0001) == 1) // 000
-          return -0.6961928009986877f*absmax;
+          return -0.6961928009986877f;
        else
-          return -1.0f*absmax;
+          return -1.0f;
 }
+
+__device__ float dDequantizeNF4(unsigned char val)
+{
+  // the values for this tree was generated by test_normal_map_tree
+  // in the file tests/test_functional.py
+  if((val & 0b1000) == 8)
+    if((val & 0b0100) == 4) // 1
+      if((val & 0b0010) == 2) // 11
+        if((val & 0b0001) == 1) // 111
+          return 1.0f;
+        else
+          return 0.7229568362236023f;
+      else
+        if((val & 0b0001) == 1) // 110
+          return 0.5626170039176941f;
+        else
+          return 0.44070982933044434f;
+    else
+      if((val & 0b0010) == 2) //10
+        if((val & 0b0001) == 1) // 101
+          return 0.33791524171829224f;
+        else
+          return 0.24611230194568634f;
+      else
+        if((val & 0b0001) == 1) // 100
+          return 0.16093020141124725f;
+        else
+          return 0.07958029955625534f;
+  else
+    if((val & 0b0100) == 4) // 0
+      if((val & 0b0010) == 2) //01
+        if((val & 0b0001) == 1) // 011
+          return 0.0f;
+        else
+          return -0.09105003625154495f;
+      else
+        if((val & 0b0001) == 1) // 010
+          return -0.18477343022823334f;
+        else
+          return -0.28444138169288635f;
+    else
+      if((val & 0b0010) == 2) //00
+        if((val & 0b0001) == 1) // 001
+          return -0.39491748809814453f;
+        else
+          return -0.5250730514526367f;
+      else
+        if((val & 0b0001) == 1) // 000
+          return -0.6961928009986877f;
+        else
+          return -1.0f;
+}
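Functionally, the branchy tree above is a 16-entry code table indexed by the 4-bit value. A NumPy sketch of the equivalent lookup, with the constants copied verbatim from the returns above (note code 0b0111 maps to 0.0, which the padding trick below relies on):

    import numpy as np

    # NF4 code values, indexed by the 4-bit code 0..15
    NF4_CODE = np.array([
        -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
        -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
        0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
        0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
        0.7229568362236023, 1.0], dtype=np.float32)

    def dequantize_nf4(val):
        return NF4_CODE[val]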
@@ -800,8 +875,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
       #pragma unroll NUM_PER_TH
       for(int j = 0; j < NUM_PER_TH; j++)
       {
-        vals[j*2] = dDequantizeNF4(qvals[j] >> 4, local_abs_max);
-        vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F, local_abs_max);
+        vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max;
+        vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max;
       }
       break;
     }
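The refactor above separates the two factors: dDequantizeNF4 now returns the raw code value and the caller applies the per-block absmax scale. The same arithmetic in NumPy (a sketch reusing NF4_CODE from above; blocksize consecutive elements share one absmax):

    def dequantize_blockwise_nf4(packed, absmax, blocksize):
        codes = np.empty(packed.size * 2, dtype=np.uint8)
        codes[0::2] = packed >> 4      # high nibble is the first element, as in the kernel
        codes[1::2] = packed & 0x0F
        scales = np.repeat(absmax, blocksize)[:codes.size].astype(np.float32)
        return NF4_CODE[codes] * scales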
@@ -2947,7 +3022,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 //// 9. write outputs to matmul output matrix
 //}
 
-template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit)
+template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f)
 {
   if(limit_base + ITEMS <= limit)
     reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
@@ -2958,7 +3033,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
       if(limit_base + k < limit)
        local[k] = buffer[idx+k];
      else
-       local[k] = 0.0f;
+       local[k] = (T)zero_value;
     }
   }
 }
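The new zero_value parameter lets callers choose the padding written for out-of-bounds tail elements. That matters for the 4-bit kernel below: an all-zero pad byte would decode to two NF4 codes of -1.0 each, so the pad must be the byte whose two nibbles both decode to 0.0. A quick check against the table sketch above:

    pad = (0b0111 << 4) | 0b0111        # = 0b01110111, two "zero" nibbles
    assert NF4_CODE[pad >> 4] == 0.0
    assert NF4_CODE[pad & 0x0F] == 0.0
    assert NF4_CODE[0x00 >> 4] == -1.0  # why an all-zero pad byte would be wrong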
@@ -3024,6 +3099,109 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
 }
 
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+  typedef cub::BlockReduce<T, THREADS> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage reduce;
+
+  int col_offset = blockIdx.x *8;
+
+  T local_A[32];
+  unsigned char local_B_4bit[16];
+  T local_B[32];
+  T local_C[8];
+
+  __shared__ T smem_C[8];
+
+  if(threadIdx.x < 8)
+    smem_C[threadIdx.x] = T(0);
+  __syncthreads();
+
+  #pragma unroll 8
+  for(int k = 0; k < 8; k++)
+    local_C[k] = T(0);
+
+  for(int idx = threadIdx.x*32; idx < K; idx+=blockDim.x*32)
+  {
+    // we load only 8 values per iteration from A, so we need to do
+    // 4 loads for every single load from B; B holds packed values,
+    // so one 16-byte load yields 32 4-bit values, matching the
+    // 4x 8-element loads from A
+    vector_load<T, int4, 8>(local_A, A, idx, idx, K);
+    vector_load<T, int4, 8>(&(local_A[8]), A, idx+8, idx+8, K);
+    vector_load<T, int4, 8>(&(local_A[16]), A, idx+16, idx+16, K);
+    vector_load<T, int4, 8>(&(local_A[24]), A, idx+24, idx+24, K);
+
+    for(int col = 0; col < 8; col++)
+    {
+      if((col + col_offset) >= M){ break; }
+
+      int offset_B = (col_offset+col)*ldb;
+      // 0111 -> 0.0f in NF4
+      // since we have packed 8-bits, the pad value is cat(0b0111, 0b0111) = 0b01110111
+      vector_load<unsigned char, int4, 16>(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111);
+      int absidx = (idx + offset_B)/blocksize;
+      half local_absmax = __ldg(&(absmax[absidx]));
+
+      //for(int k = 0; k < 16; k++)
+        //printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F);
+      //printf("\n");
+      //vector_load<T, int4, 8>(local_A, A, idx, idx, K);
+
+      #pragma unroll 16
+      for(int k = 0; k < 16; k++)
+      {
+        //if(local_B_4bit[k ] != 0b01110111)
+        //printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax),
+        //local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax));
+        //local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax;
+        //local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax;
+        local_B[k*2] = (half)(local_B_4bit[k] >> 4)*local_absmax;
+        local_B[k*2 + 1] = (half)(local_B_4bit[k] & 0x0F)*local_absmax;
+        //local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax;
+        //local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax;
+      }
+
+      #pragma unroll 32
+      //for(int k = 0; k < 8; k++)
+      for(int k = 0; k < 32; k++)
+      {
+        local_C[col] += local_A[k]*local_B[k];
+        //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
+        //if((float)local_B[k] != 0.0)
+        //printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]);
+      }
+    }
+  }
+
+  #pragma unroll 8
+  for(int k = 0; k < 8; k++)
+  {
+    local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
+    __syncthreads();
+  }
+
+  if(threadIdx.x == 0)
+  {
+    #pragma unroll 8
+    for(int k = 0; k < 8; k++)
+      smem_C[k] = local_C[k];
+  }
+  else if(threadIdx.x >= 32)
+    // early return for unused warps
+    return;
+
+  __syncwarp();
+
+  if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
+    out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
+}
+
 //#define ROWS 2
 //template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
 //{
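In this draft, each 128-thread block produces 8 output values; threads stride over K in chunks of 32 and the results are combined with a cub BlockReduce. Note that the inner loop currently multiplies the raw nibble by absmax (the dDequantizeNF4 calls are still commented out). A NumPy reference of that draft arithmetic, assuming already-unpacked nibbles and ldb == K:

    def kgemm_4bit_draft_reference(A_row, B_nibbles, absmax, blocksize):
        # A_row: (K,) activations; B_nibbles: (M, K) uint8 codes in 0..15;
        # absmax: one scale per `blocksize` consecutive weights, flat order.
        M, K = B_nibbles.shape
        scales = np.repeat(absmax, blocksize).reshape(M, K).astype(np.float32)
        B_deq = B_nibbles.astype(np.float32) * scales  # raw nibble * absmax, as in the draft
        return B_deq @ A_row.astype(np.float32)        # out[col] = sum_k A[k] * B_deq[col, k]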
@@ -3207,6 +3385,8 @@ template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half *
 template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 
+template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+
 //template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
 template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz);


@@ -139,5 +139,6 @@ template <size_t stages_count /* Pipeline with stages_count stages */>
 __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
 
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
 
 #endif


@@ -695,10 +695,28 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
   gemm_device<T, 16, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
 }
 
+template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+  dim3 dimBlock(128);
+  int num_blocks = (m+7)/8;
+
+  cout << num_blocks << endl;
+  cout << lda << endl;
+  cout << ldb << endl;
+  cout << ldc << endl;
+  cout << m << endl;
+  cout << n << endl;
+  cout << k << endl;
+
+  kgemm_4bit_inference<T, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+}
+
 //==============================================================
 // TEMPLATE DEFINITIONS
 //==============================================================
 
+template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
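Launch geometry for the host wrapper above: since each 128-thread block computes 8 output values, the grid size is ceil(m / 8), which is what (m+7)/8 expresses (the cout lines are leftover draft instrumentation). For the shapes used in the new test:

    m = 4 * 4096                 # rows of B in test_gemm_4bit
    num_blocks = (m + 7) // 8    # -> 2048 blocks of 128 threads each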


@@ -191,6 +191,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
 void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
+template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
 
 void pipeline_test(float *A, float *B, size_t n, size_t batch_size);


@@ -25,6 +25,9 @@ void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, in
 void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
 { gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
 
+void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
 void fname##32bit_g##gbits(gtype *g, gtype *p, \
@ -319,6 +322,9 @@ extern "C"
void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }


@@ -2352,8 +2352,8 @@ def test_normal_map_tree():
     print(pivots)
 
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
     for i in range(1):
         #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
@@ -2373,6 +2373,32 @@ def test_cutlass3_gemm(dtype):
     torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005)
 
+#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+def test_gemm_4bit(dtype):
+    for i in range(1):
+        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+        #torch.random.manual_seed(17)
+        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+
+        #print('')
+        #print(A)
+        #print(B)
+
+        qB, state = F.quantize_nf4(B)
+        F.dequantize_nf4(qB, state)
+
+        C1 = torch.matmul(A, B.t())
+        #C1 = bnb.matmul_4bit(A, qB.t(), state)
+        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
+        #print(C1)
+        #print(C2)
+
+        #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005)
+
 def test_pipeline_func():
     a = torch.rand(2, 4).cuda()
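A note on the disabled assertion in test_gemm_4bit: comparing C2 against the full-precision product C1 conflates kernel error with NF4 quantization error. A hypothetical variant (not in this commit) that isolates the kernel would compare against the dequantized weights, assuming dequantize_nf4 returns the reconstructed tensor as its call above suggests:

    # hypothetical kernel-only check; tolerances are a guess for fp16
    C3 = torch.matmul(A, F.dequantize_nf4(qB, state).t())
    torch.testing.assert_close(C2, C3, atol=1e-2, rtol=0.05)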