Initial 4-bit naive batch size 1, 81 vs 185.
parent e54d2730fc
commit f89ff93e26
Changed paths: bitsandbytes/, csrc/, tests/
@@ -1503,7 +1503,7 @@ def cutlass3_gemm(
     ldc = ct.c_int32(ldc)

     if B.dtype == torch.uint8:
-        lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+        lib.cgemm_4bit_inference_naive(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
     elif A.dtype == torch.float32:
         lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     elif A.dtype == torch.float16:
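Note (illustrative sketch, not part of the diff): the dispatch above is exercised from Python roughly as follows, using the test-side API that appears further down (F.quantize_nf4, F.cutlass3_gemm, bnb.matmul_4bit):

    import math
    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    dim = 4096
    A = torch.randn(1, dim, dtype=torch.float16, device='cuda')                       # batch size 1
    B = torch.randn(4*dim, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

    qB, state = F.quantize_nf4(B)                      # packed uint8 weights + quantization state
    C_naive = F.cutlass3_gemm(A, qB.t(), state=state)  # uint8 B now routes to cgemm_4bit_inference_naive
    C_ref = bnb.matmul_4bit(A, qB.t(), state)          # existing 4-bit path, used for comparison in the tests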
csrc/kernels.cu (162 changed lines)
@@ -3088,7 +3088,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
   }
 }

-#define WARPS 5
+#define WARPS 3
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {

@@ -3298,15 +3298,15 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
 }


-template <typename T> __device__ void printnonzero(T *A, int num_values)
+template <typename T> __device__ void printnonzero(T *A, int num_values, const char * strval)
 {
   for(int i = 0; i < num_values; i++)
     if((float)A[i] != 0.0)
-      printf("%i %f\n", i, (float)A[i]);
+      printf("%s %i %f\n", strval, i, (float)A[i]);
 }

-template __device__ void printnonzero<float>(float *A, int num_values);
-template __device__ void printnonzero<half>(half *A, int num_values);
+template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
+template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);

 __device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
@@ -3315,6 +3315,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   using namespace nvcuda;
   int col_offset = blockIdx.x *32;
   const int warp_id = threadIdx.x / 32;
+  const int warp_idx = threadIdx.x % 32;
   const int half_warp_id = threadIdx.x / 16;
   const int half_warp_lane = threadIdx.x % 16;
   const int batch_size_warps = (WARPS-1)*2;
@@ -3324,23 +3325,30 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   #pragma unroll 16
   for(int i = 0; i < 16; i++)
     quant_map[i] = nf4_data[i];
+  //__shared__ T quant_map[16*160];

   T local_A[2];
   T local_B[64];
   unsigned char local_B_4bit[32];


   const int a_tile_offset = 16;
   const int b_tile_offset = (16*32 + 16);

-  __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
+  __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))];
   __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
-  //__shared__ T smem_C[8*32];
+  __shared__ T smem_C[8*32];

   wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
   wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
   wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
   wmma::fill_fragment(c_frag, 0.0f);

+  for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
+    smem_C[i] = 0.0f;
+
+  __syncthreads();

   int ticktock = 0;
   int idx = 0 + threadIdx.x;
   int loaded_values = 0;
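Note (sketch of the tile arithmetic, not part of the commit): with WARPS = 3 the double-buffer depth and the shared-memory element counts above work out to

    WARPS = 3
    batch_size_warps = (WARPS - 1) * 2                                   # 4 half-warp tiles per buffer stage
    smem_A = 8*16 + 16*(batch_size_warps - 1)                            # 176 elements
    smem_B = 2*batch_size_warps*16*32 + 2*16*(batch_size_warps - 1)      # 4192 elements
    smem_C = 8*32                                                        # 256 elements
    print(batch_size_warps, smem_A, smem_B, smem_C)                      # 4 176 4192 256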
@@ -3366,8 +3374,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
       #pragma unroll 64
       for(int col = 0; col < 64; col+=2)
       {
-        local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
-        local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+        //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
+        //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+        //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
+        //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
+        //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
+        //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
+
+        //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
+        //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
+        local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0);
+        local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0);
       }
     }
@@ -3391,13 +3408,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
         smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
     }
     ticktock = ticktock == 0 ? 1 : 0;
+    //if(threadIdx.x == 0)
+    //printf("aa %i %i\n", idx, loaded_values);

+  //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   {
     idx = base_idx + threadIdx.x;
+    //if(threadIdx.x == 0)
+    //printf("%i %i\n", idx, loaded_values);

     __syncthreads();
     //__syncthreads();
     if(idx < K && warp_id < (WARPS-1))
     {
       if(loaded_values == 0)
@@ -3425,11 +3446,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
         #pragma unroll 64
         for(int col = 0; col < 64; col+=2)
         {
-          local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
-          local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
-          //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
-          //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
+          //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
+          //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+          //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
+          //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
+
+          //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
+          //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
+          local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
+          local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
         }
+        //printnonzero<T>(local_B, 128, "");
       }

       smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
@@ -3463,6 +3490,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   }

   __syncthreads();
+  //if(threadIdx.x == 0)
+  //{
+  //  printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
+  //  printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
+  //}
   if(warp_id != (WARPS-1)){ return; }
   // only warp_id == (WARPS-1) from here
   int warp_lane = threadIdx.x % 32;
@@ -3470,6 +3502,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   ticktock = ticktock == 0 ? 1 : 0;
   for(int k = 0; k < batch_size_warps; k++)
   {
+    //if(warp_lane == 0)
+    //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
     wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
     wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
     wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
@@ -3477,14 +3511,101 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i

   // 129 mu
   if(warp_id == (WARPS-1))
-    wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
+    wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);

-  printnonzero<T>(smem_A, 32);
+  //printnonzero<T>(smem_C, 32, "");

   if(col_offset + warp_lane < M)
-    out[col_offset + warp_lane] = smem_A[warp_lane];
+    out[col_offset + warp_lane] = smem_C[warp_lane];
 }

+#define num_values_4bit 16
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+
+  // per threadblock:
+  // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
+  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
+  // 4 warps -> 4 loads per iter
+  // 1x128 * 128x4 -> 1x4 outputs
+  typedef cub::WarpReduce<T> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_storage[4];
+
+  const int warp_idx = threadIdx.x / 32;
+  const int warp_lane = threadIdx.x % 32;
+  const int row_B = 4*blockIdx.x + warp_idx;
+  T local_C = T(0);
+
+  T quant_map[16];
+  #pragma unroll 16
+  for(int i = 0; i < 16; i++)
+    quant_map[i] = nf4_data[i];
+
+  unsigned char local_B_4bit[num_values_4bit/2];
+  T local_B[num_values_4bit];
+
+  // need to increase occupancy by splitting the rows, but can be done later
+
+  // A: [1, K]
+  // B: [N, K]
+  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
+  {
+    int offset_B = ldb*row_B + (inner_idx/2);
+    int absidx = (2*offset_B)/blocksize;
+    T local_absmax = __ldg(&(absmax[absidx]));
+
+    //printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit/2; k++)
+    {
+      if((inner_idx/2) < K && row_B < M)
+        local_B_4bit[k] = B[offset_B + k];
+      else
+        local_B_4bit[k] = 0b01110111;
+    }
+
+
+    //if(row_B < M)
+    //{
+    //  if((inner_idx/num_values_4bit) < K)
+    //    reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
+    //  else
+    //  {
+    //    for(int k = 0; k < num_values_4bit/2; k++)
+    //    {
+    //      if((inner_idx/2) < K && row_B < M)
+    //        local_B_4bit[k] = B[offset_B + k];
+    //      else
+    //        local_B_4bit[k] = 0b01110111;
+    //    }
+    //  }
+    //}
+
+
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit; k++)
+    {
+      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
+      local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+    }
+
+    //printnonzero<T>(local_B, 4, "B values: ");
+
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit; k++)
+      local_C += A[inner_idx + k]*local_B[k];
+
+  }
+
+  local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
+
+  if(row_B < M && warp_lane == 0)
+    out[row_B] = local_C;
+
+}

 //#define ROWS 2
 //template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
 //{
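Note (illustrative Python reference, not code from this commit): each warp of kgemm_4bit_inference_naive dequantizes one NF4-packed row of B through the 16-entry quant_map and the per-block absmax, then dot-products it with the batch-size-1 activation A. For simplicity this sketch indexes absmax from the start of the row, whereas the kernel derives absidx from the flattened offset into B:

    # NF4 code book, copied from nf4_data above
    nf4_data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
                -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
                0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
                0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
                0.7229568362236023, 1.0]

    def nf4_row_dot(A, B_row_packed, absmax, blocksize):
        """Dot product of A (length K) with one NF4-packed row of B (K/2 bytes).

        Each byte holds two 4-bit codes (high nibble first); every group of
        `blocksize` unpacked values shares one absmax scale."""
        acc = 0.0
        for j, byte in enumerate(B_row_packed):
            scale = absmax[(2 * j) // blocksize]
            acc += A[2 * j] * nf4_data[byte >> 4] * scale
            acc += A[2 * j + 1] * nf4_data[byte & 0x0F] * scale
        return acc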
@@ -3647,8 +3768,15 @@ template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * _
 template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);

 template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

+template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+
 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
@@ -106,6 +106,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi

 template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);


 __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);

@@ -124,6 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *

 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);

csrc/ops.cu (24 changed lines)
@@ -723,7 +723,28 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   //cout << m << endl;
   //cout << n << endl;
   //cout << k << endl;
-  kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
   //kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
   //kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }

+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+
+  int num_blocks = (m+3)/4;
+
+  cout << num_blocks << endl;
+  //cout << lda << endl;
+  //cout << ldb << endl;
+  //cout << ldc << endl;
+
+  //cout << m << endl;
+  //cout << n << endl;
+  //cout << k << endl;
+  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+}
+
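Note (sketch of the launch arithmetic, not part of the diff): the new launcher uses 128 threads = 4 warps per block, and each warp produces one output row (row_B = 4*blockIdx.x + warp_idx), so m output rows need ceil(m/4) blocks, i.e. num_blocks = (m+3)/4:

    def naive_4bit_grid(m, threads_per_block=128):
        warps_per_block = threads_per_block // 32                      # 4 warps, one output row each
        num_blocks = (m + warps_per_block - 1) // warps_per_block      # same as (m+3)//4 for 128 threads
        return num_blocks, threads_per_block

    print(naive_4bit_grid(16384))   # (4096, 128)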
@@ -747,6 +768,7 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
 template void func<float, _MUL>(float *A, float *B, float value, long n);

 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
@@ -200,6 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows

 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n);

@@ -28,6 +28,9 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
 void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
 { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

+void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+
 #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
 void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \

@@ -345,6 +348,9 @@ extern "C"
   void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
   { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

+  void cgemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
+  { gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+
   void *cget_managed_ptr(size_t bytes)
   {
     void *ptr;
@@ -1773,17 +1773,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)


-batch_size = 1
-seqdim = 1
+batch_size = 32
+seqdim = 512+256
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))
 #values.append((batch_size, seqdim, 1024, 4*1024))
 #values.append((batch_size, seqdim, 1536, 4*1536))
 #values.append((batch_size, seqdim, 2048, 4*2048))
 #values.append((batch_size, seqdim, 2560, 4*2560))
-values.append((batch_size, seqdim, 4096, 4*4096))
-values.append((batch_size, seqdim, 5120, 4*5120))
-values.append((batch_size, seqdim, 6656, 4*6656))
+#values.append((batch_size, seqdim, 4096, 4*4096))
+#values.append((batch_size, seqdim, 5120, 4*5120))
+#values.append((batch_size, seqdim, 6656, 4*6656))
 values.append((batch_size, seqdim, 8192, 4*8192))
 #values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
@@ -1827,19 +1827,19 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
-    torch.cuda.synchronize()
-    print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
+    #torch.cuda.synchronize()
+    #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
-    torch.cuda.synchronize()
-    print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
+    #torch.cuda.synchronize()
+    #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

     torch.cuda.synchronize()
     t0 = time.time()
@@ -1901,21 +1901,21 @@ def test_bench_matmul(batch, seq, model, hidden):
     #torch.cuda.synchronize()
     #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linear8bit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linearMixedBit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linearMixedBit(A)
-    torch.cuda.synchronize()
-    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     #linear8bit_train(A)
     #torch.cuda.synchronize()
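Note (assumed helper, not something this test file defines): every benchmark block in test_bench_matmul follows the same warm-up/synchronize/time pattern, which condenses to

    import time
    import torch

    def bench(fn, iters=100, label=""):
        fn()                             # warm-up / lazy initialization
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        print(f"{label}: {time.time()-t0:.4f}s")

    # e.g. bench(lambda: bnb.matmul_4bit(A, B_fp4.t(), quant_state=state), label="bnb fp4")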
@@ -2411,10 +2411,14 @@ def test_cutlass3_gemm(dtype):
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_gemm_4bit(dtype):
+    print('')
     #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+    #for dim in [4096, 5120, 6656, 8192]:
     #for dim in [32]:
-    for dim in [32]:
+    for dim in [4096]:
     #for dim in [5120]:
     #for dim in [6656]:
     #for dim in [128]:
        errs = []
        relerrs = []
        max_err = 0
@@ -2424,24 +2428,36 @@ def test_gemm_4bit(dtype):
        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
        #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
        #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-       A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
-       B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+       A = torch.randn(1, dim+2, dtype=dtype, device='cuda')
+       B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
+       #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)

        #print('')
        #print(A)
        #print(B.t())
+       #A[:, :-1] = 0
+       #B[:, :-1] = 0
+       #A.flatten()[:-1] = 0
+       #B.flatten()[:-1] = 0

        qB, state = F.quantize_nf4(B)
+       F.dequantize_nf4(qB, state)

-       C3 = torch.matmul(A, B.t())
+       #C3 = torch.matmul(A, B.t())
        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
        C1 = bnb.matmul_4bit(A, qB.t(), state)

+       print(C1)
+       print(C2)
+       #print(state)
+       #print(qB)


        #print('')
        #print(A)
        #print(B)
        #print('='*89)
        #print(C1)
        #print(C2)
        #print(C3)

        #print(C1.shape, C2.shape)

@@ -2455,7 +2471,7 @@ def test_gemm_4bit(dtype):
            max_relerr = max(relerr.max(), max_relerr)
            err = err.mean().item()
            relerr = relerr.mean().item()
-           print(err)
+           #print(err)

            errs.append(err)
            relerrs.append(relerr)
@@ -2463,20 +2479,20 @@ def test_gemm_4bit(dtype):
            if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
                print('')
                print(i, err, relerr)
-               print(A.flatten()[-6:])
-               print(B.flatten()[-6:])
-               out = A.flatten()[-6:]*B.flatten()[-6:]
-               print(out)
-               print(out[:-1].sum())
+               #print(A.flatten()[-6:])
+               #print(B.flatten()[-6:])
+               #out = A.flatten()[-6:]*B.flatten()[-6:]
+               #print(out)
+               #print(out[:-1].sum())
                print('='*80)
-               print(C1.flatten()[-6:])
-               print(C2.flatten()[-6:])
+               #print(C1.flatten()[-6:])
+               #print(C2.flatten()[-6:])
                #assert False, 'ERROR'

        c = int(C1.numel()*0.0014*(dim/256))+1

        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
-       #print(c/math.sqrt(dim))
+       print(c/math.sqrt(dim))
        print('')
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
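Note (hedged sketch; assert_all_approx_close is the test suite's own helper and is used as-is above): with C1 the existing matmul_4bit output and C2 the new kernel output, the accuracy check reduces to mean absolute and relative error plus a dim-scaled tolerated mismatch count:

    import torch

    def error_stats(C1, C2):
        err = torch.abs(C1.float() - C2.float())
        relerr = err / (torch.abs(C1.float()) + 1e-8)   # epsilon guard is an assumption
        return err.mean().item(), relerr.mean().item()

    # tolerated mismatch count grows with the matrix size, as in the test:
    # c = int(C1.numel() * 0.0014 * (dim / 256)) + 1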