Initial 4-bit naive batch size 1, 81 vs 185.

This commit is contained in:
Tim Dettmers 2023-07-03 18:45:38 -07:00
parent e54d2730fc
commit f89ff93e26
7 changed files with 240 additions and 65 deletions

View File

@ -1503,7 +1503,7 @@ def cutlass3_gemm(
ldc = ct.c_int32(ldc)
if B.dtype == torch.uint8:
lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
lib.cgemm_4bit_inference_naive(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
elif A.dtype == torch.float32:
lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
elif A.dtype == torch.float16:

View File

@ -3088,7 +3088,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
}
}
#define WARPS 5
#define WARPS 3
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
@ -3298,15 +3298,15 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
}
template <typename T> __device__ void printnonzero(T *A, int num_values)
template <typename T> __device__ void printnonzero(T *A, int num_values, const char * strval)
{
for(int i = 0; i < num_values; i++)
if((float)A[i] != 0.0)
printf("%i %f\n", i, (float)A[i]);
printf("%s %i %f\n", strval, i, (float)A[i]);
}
template __device__ void printnonzero<float>(float *A, int num_values);
template __device__ void printnonzero<half>(half *A, int num_values);
template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
@ -3315,6 +3315,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
using namespace nvcuda;
int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;
@ -3324,23 +3325,30 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
#pragma unroll 16
for(int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i];
//__shared__ T quant_map[16*160];
T local_A[2];
T local_B[64];
unsigned char local_B_4bit[32];
const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16);
__shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
__shared__ T smem_A[8*16 + (16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
//__shared__ T smem_C[8*32];
__shared__ T smem_C[8*32];
wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
smem_C[i] = 0.0f;
__syncthreads();
int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
@ -3366,8 +3374,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
#pragma unroll 64
for(int col = 0; col < 64; col+=2)
{
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
//local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
//local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
//local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
//local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
//local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0);
local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0);
}
}
@ -3391,13 +3408,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
//if(threadIdx.x == 0)
//printf("aa %i %i\n", idx, loaded_values);
//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
{
idx = base_idx + threadIdx.x;
//if(threadIdx.x == 0)
//printf("%i %i\n", idx, loaded_values);
__syncthreads();
//__syncthreads();
if(idx < K && warp_id < (WARPS-1))
{
if(loaded_values == 0)
@ -3425,11 +3446,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
#pragma unroll 64
for(int col = 0; col < 64; col+=2)
{
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
//local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
//local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
}
//printnonzero<T>(local_B, 128, "");
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
@ -3463,6 +3490,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
}
__syncthreads();
//if(threadIdx.x == 0)
//{
// printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
// printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
//}
if(warp_id != (WARPS-1)){ return; }
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;
@ -3470,6 +3502,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++)
{
//if(warp_lane == 0)
//printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
@ -3477,14 +3511,101 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
// 129 mu
if(warp_id == (WARPS-1))
wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
printnonzero<T>(smem_A, 32);
//printnonzero<T>(smem_C, 32, "");
if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
out[col_offset + warp_lane] = smem_C[warp_lane];
}
#define num_values_4bit 16
template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
// per threadblock:
// load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
// 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
// 4 warps -> 4 loads per iter
// 1x128 * 128x4 -> 1x4 outputs
typedef cub::WarpReduce<T> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[4];
const int warp_idx = threadIdx.x / 32;
const int warp_lane = threadIdx.x % 32;
const int row_B = 4*blockIdx.x + warp_idx;
T local_C = T(0);
T quant_map[16];
#pragma unroll 16
for(int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i];
unsigned char local_B_4bit[num_values_4bit/2];
T local_B[num_values_4bit];
// need to increase occupancy by splitting the rows, but can be done later
// A: [1, K]
// B: [N, K]
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
{
int offset_B = ldb*row_B + (inner_idx/2);
int absidx = (2*offset_B)/blocksize;
T local_absmax = __ldg(&(absmax[absidx]));
//printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
#pragma unroll
for(int k = 0; k < num_values_4bit/2; k++)
{
if((inner_idx/2) < K && row_B < M)
local_B_4bit[k] = B[offset_B + k];
else
local_B_4bit[k] = 0b01110111;
}
//if(row_B < M)
//{
// if((inner_idx/num_values_4bit) < K)
// reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
// else
// {
// for(int k = 0; k < num_values_4bit/2; k++)
// {
// if((inner_idx/2) < K && row_B < M)
// local_B_4bit[k] = B[offset_B + k];
// else
// local_B_4bit[k] = 0b01110111;
// }
// }
//}
#pragma unroll
for(int k = 0; k < num_values_4bit; k++)
{
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
}
//printnonzero<T>(local_B, 4, "B values: ");
#pragma unroll
for(int k = 0; k < num_values_4bit; k++)
local_C += A[inner_idx + k]*local_B[k];
}
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
if(row_B < M && warp_lane == 0)
out[row_B] = local_C;
}
//#define ROWS 2
//template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
//{
@ -3647,8 +3768,15 @@ template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * _
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

View File

@ -106,6 +106,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
@ -124,6 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);

View File

@ -723,7 +723,28 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
//cout << m << endl;
//cout << n << endl;
//cout << k << endl;
kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
int num_blocks = (m+3)/4;
cout << num_blocks << endl;
//cout << lda << endl;
//cout << ldb << endl;
//cout << ldc << endl;
//cout << m << endl;
//cout << n << endl;
//cout << k << endl;
kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
@ -747,6 +768,7 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);

View File

@ -200,6 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int FUNC> void func(T *A, T *B, T value, long n);

View File

@ -28,6 +28,9 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
@ -345,6 +348,9 @@ extern "C"
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
void cgemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
void *cget_managed_ptr(size_t bytes)
{
void *ptr;

View File

@ -1773,17 +1773,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
print("partial matmul", time.time() - t0)
batch_size = 1
seqdim = 1
batch_size = 32
seqdim = 512+256
values = []
#values.append((batch_size, seqdim, 768, 4 * 768))
#values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
values.append((batch_size, seqdim, 4096, 4*4096))
values.append((batch_size, seqdim, 5120, 4*5120))
values.append((batch_size, seqdim, 6656, 4*6656))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5120, 4*5120))
#values.append((batch_size, seqdim, 6656, 4*6656))
values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
@ -1827,19 +1827,19 @@ def test_bench_matmul(batch, seq, model, hidden):
torch.cuda.synchronize()
print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
torch.cuda.synchronize()
print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
#torch.cuda.synchronize()
#print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
torch.cuda.synchronize()
print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
#torch.cuda.synchronize()
#print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch.cuda.synchronize()
t0 = time.time()
@ -1901,21 +1901,21 @@ def test_bench_matmul(batch, seq, model, hidden):
#torch.cuda.synchronize()
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linear8bit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linear8bit(A)
torch.cuda.synchronize()
print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit(A)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# linear8bit(A)
#torch.cuda.synchronize()
#print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linearMixedBit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linearMixedBit(A)
torch.cuda.synchronize()
print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linearMixedBit(A)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(iters):
# linearMixedBit(A)
#torch.cuda.synchronize()
#print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
#linear8bit_train(A)
#torch.cuda.synchronize()
@ -2411,10 +2411,14 @@ def test_cutlass3_gemm(dtype):
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_gemm_4bit(dtype):
print('')
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
#for dim in [4096, 5120, 6656, 8192]:
#for dim in [32]:
for dim in [32]:
for dim in [4096]:
#for dim in [5120]:
#for dim in [6656]:
#for dim in [128]:
errs = []
relerrs = []
max_err = 0
@ -2424,24 +2428,36 @@ def test_gemm_4bit(dtype):
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
A = torch.randn(1, dim+2, dtype=dtype, device='cuda')
B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
#B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
#print('')
#print(A)
#print(B.t())
#A[:, :-1] = 0
#B[:, :-1] = 0
#A.flatten()[:-1] = 0
#B.flatten()[:-1] = 0
qB, state = F.quantize_nf4(B)
F.dequantize_nf4(qB, state)
C3 = torch.matmul(A, B.t())
#C3 = torch.matmul(A, B.t())
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
C1 = bnb.matmul_4bit(A, qB.t(), state)
print(C1)
print(C2)
#print(state)
#print(qB)
#print('')
#print(A)
#print(B)
#print('='*89)
#print(C1)
#print(C2)
#print(C3)
#print(C1.shape, C2.shape)
@ -2455,7 +2471,7 @@ def test_gemm_4bit(dtype):
max_relerr = max(relerr.max(), max_relerr)
err = err.mean().item()
relerr = relerr.mean().item()
print(err)
#print(err)
errs.append(err)
relerrs.append(relerr)
@ -2463,20 +2479,20 @@ def test_gemm_4bit(dtype):
if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
print('')
print(i, err, relerr)
print(A.flatten()[-6:])
print(B.flatten()[-6:])
out = A.flatten()[-6:]*B.flatten()[-6:]
print(out)
print(out[:-1].sum())
#print(A.flatten()[-6:])
#print(B.flatten()[-6:])
#out = A.flatten()[-6:]*B.flatten()[-6:]
#print(out)
#print(out[:-1].sum())
print('='*80)
print(C1.flatten()[-6:])
print(C2.flatten()[-6:])
#print(C1.flatten()[-6:])
#print(C2.flatten()[-6:])
#assert False, 'ERROR'
c = int(C1.numel()*0.0014*(dim/256))+1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
#print(c/math.sqrt(dim))
print(c/math.sqrt(dim))
print('')
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))