From f89ff93e26d02037db30e88053983d6bb12dd660 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 3 Jul 2023 18:45:38 -0700 Subject: [PATCH] Initial 4-bit naive batch size 1, 81 vs 185. --- bitsandbytes/functional.py | 2 +- csrc/kernels.cu | 162 +++++++++++++++++++++++++++++++++---- csrc/kernels.cuh | 2 + csrc/ops.cu | 24 +++++- csrc/ops.cuh | 1 + csrc/pythonInterface.c | 6 ++ tests/test_functional.py | 108 ++++++++++++++----------- 7 files changed, 240 insertions(+), 65 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index afa346e..3ae4237 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1503,7 +1503,7 @@ def cutlass3_gemm( ldc = ct.c_int32(ldc) if B.dtype == torch.uint8: - lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + lib.cgemm_4bit_inference_naive(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) elif A.dtype == torch.float32: lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) elif A.dtype == torch.float16: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index ea0be06..216d436 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3088,7 +3088,7 @@ template __device__ inline void vector_l } } -#define WARPS 5 +#define WARPS 3 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3298,15 +3298,15 @@ template __global__ void gemm_device(int M, } -template __device__ void printnonzero(T *A, int num_values) +template __device__ void printnonzero(T *A, int num_values, const char * strval) { for(int i = 0; i < num_values; i++) if((float)A[i] != 0.0) - printf("%i %f\n", i, (float)A[i]); + printf("%s %i %f\n", strval, i, (float)A[i]); } -template __device__ void printnonzero(float *A, int num_values); -template __device__ void printnonzero(half *A, int num_values); +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template __device__ void printnonzero(half *A, int num_values, const char*strval); __device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) @@ -3315,6 +3315,7 @@ template __global__ void kgemm_4bit_inference(int M, i using namespace nvcuda; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; @@ -3324,23 +3325,30 @@ template __global__ void kgemm_4bit_inference(int M, i #pragma unroll 16 for(int i = 0; i < 16; i++) quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; T local_A[2]; T local_B[64]; unsigned char local_B_4bit[32]; + const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); - __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - //__shared__ T 
smem_C[8*32]; + __shared__ T smem_C[8*32]; wmma::fragment a_frag; wmma::fragment b_frag; wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); + for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + int ticktock = 0; int idx = 0 + threadIdx.x; int loaded_values = 0; @@ -3366,8 +3374,17 @@ template __global__ void kgemm_4bit_inference(int M, i #pragma unroll 64 for(int col = 0; col < 64; col+=2) { - local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); - local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); } } @@ -3391,13 +3408,17 @@ template __global__ void kgemm_4bit_inference(int M, i smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; } ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) { idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); - __syncthreads(); + //__syncthreads(); if(idx < K && warp_id < (WARPS-1)) { if(loaded_values == 0) @@ -3425,11 +3446,17 @@ template __global__ void kgemm_4bit_inference(int M, i #pragma unroll 64 for(int col = 0; col < 64; col+=2) { - local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); - local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); - //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); - //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); } + //printnonzero(local_B, 128, ""); } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; @@ -3463,6 +3490,11 @@ template __global__ void kgemm_4bit_inference(int M, i } __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} if(warp_id != (WARPS-1)){ return; } // only warp_id == (WARPS-1) from here int warp_lane = threadIdx.x % 32; @@ -3470,6 +3502,8 @@ template __global__ 
void kgemm_4bit_inference(int M, i ticktock = ticktock == 0 ? 1 : 0; for(int k = 0; k < batch_size_warps; k++) { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); @@ -3477,14 +3511,101 @@ template __global__ void kgemm_4bit_inference(int M, i // 129 mu if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); - printnonzero(smem_A, 32); + //printnonzero(smem_C, 32, ""); if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_A[warp_lane]; + out[col_offset + warp_lane] = smem_C[warp_lane]; } +#define num_values_4bit 16 +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + // per threadblock: + // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps] + // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1 + // 4 warps -> 4 loads per iter + // 1x128 * 128x4 -> 1x4 outputs + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = 4*blockIdx.x + warp_idx; + T local_C = T(0); + + T quant_map[16]; + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + + unsigned char local_B_4bit[num_values_4bit/2]; + T local_B[num_values_4bit]; + + // need to increase occupancy by splitting the rows, but can be done later + + // A: [1, K] + // B: [N, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) + { + int offset_B = ldb*row_B + (inner_idx/2); + int absidx = (2*offset_B)/blocksize; + T local_absmax = __ldg(&(absmax[absidx])); + + //printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B); + #pragma unroll + for(int k = 0; k < num_values_4bit/2; k++) + { + if((inner_idx/2) < K && row_B < M) + local_B_4bit[k] = B[offset_B + k]; + else + local_B_4bit[k] = 0b01110111; + } + + + //if(row_B < M) + //{ + // if((inner_idx/num_values_4bit) < K) + // reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[offset_B/(num_values_4bit/2)]; + // else + // { + // for(int k = 0; k < num_values_4bit/2; k++) + // { + // if((inner_idx/2) < K && row_B < M) + // local_B_4bit[k] = B[offset_B + k]; + // else + // local_B_4bit[k] = 0b01110111; + // } + // } + //} + + + + #pragma unroll + for(int k = 0; k < num_values_4bit; k++) + { + local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; + local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; + } + + //printnonzero(local_B, 4, "B values: "); + + #pragma unroll + for(int k = 0; k < num_values_4bit; k++) + local_C += A[inner_idx + k]*local_B[k]; + + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if(row_B < M && warp_lane == 0) + out[row_B] = local_C; + +} + + //#define ROWS 2 //template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) //{ @@ -3647,8 +3768,15 @@ template __global__ void 
gemm_device(int M, int N, int K, half * _ template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 30faf4a..05d0715 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -106,6 +106,7 @@ template __global__ voi template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); + __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); @@ -124,6 +125,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kfunc(T *A, T *B, T value, long n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 9c042fa..ed242c9 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -723,7 +723,28 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //cout << m << endl; //cout << n << endl; //cout << k << endl; - kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + kgemm_4bit_inference<<< num_blocks, 96, 0, 0 
>>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+3)/4; + + cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } @@ -747,6 +768,7 @@ template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 5b9a32b..e4df195 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -200,6 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); template void func(T *A, T *B, T value, long n); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 23a0364..d42f17f 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -28,6 +28,9 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ @@ -345,6 +348,9 @@ extern "C" void 
cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + void *cget_managed_ptr(size_t bytes) { void *ptr; diff --git a/tests/test_functional.py b/tests/test_functional.py index 54ceed5..752dd1d 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1773,17 +1773,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 1 -seqdim = 1 +batch_size = 32 +seqdim = 512+256 values = [] #values.append((batch_size, seqdim, 768, 4 * 768)) #values.append((batch_size, seqdim, 1024, 4*1024)) #values.append((batch_size, seqdim, 1536, 4*1536)) #values.append((batch_size, seqdim, 2048, 4*2048)) #values.append((batch_size, seqdim, 2560, 4*2560)) -values.append((batch_size, seqdim, 4096, 4*4096)) -values.append((batch_size, seqdim, 5120, 4*5120)) -values.append((batch_size, seqdim, 6656, 4*6656)) +#values.append((batch_size, seqdim, 4096, 4*4096)) +#values.append((batch_size, seqdim, 5120, 4*5120)) +#values.append((batch_size, seqdim, 6656, 4*6656)) values.append((batch_size, seqdim, 8192, 4*8192)) #values.append((batch_size, seqdim, 5140, 4*5140)) #values.append((batch_size, seqdim, 12288, 4*12288)) @@ -1827,19 +1827,19 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) - torch.cuda.synchronize() - print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) + #torch.cuda.synchronize() + #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) - torch.cuda.synchronize() - print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) + #torch.cuda.synchronize() + #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() @@ -1901,21 +1901,21 @@ def test_bench_matmul(batch, seq, model, hidden): #torch.cuda.synchronize() #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - linear8bit(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linear8bit(A) - torch.cuda.synchronize() - print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #linear8bit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit(A) + 
#torch.cuda.synchronize() + #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - linearMixedBit(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linearMixedBit(A) - torch.cuda.synchronize() - print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #linearMixedBit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linearMixedBit(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") #linear8bit_train(A) #torch.cuda.synchronize() @@ -2411,10 +2411,14 @@ def test_cutlass3_gemm(dtype): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_gemm_4bit(dtype): + print('') #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: #for dim in [32]: - for dim in [32]: + for dim in [4096]: + #for dim in [5120]: + #for dim in [6656]: + #for dim in [128]: errs = [] relerrs = [] max_err = 0 @@ -2424,24 +2428,36 @@ def test_gemm_4bit(dtype): #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') #A = torch.rand(1, 4096, dtype=dtype, device='cuda') #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') - A = torch.randn(1, dim+0, dtype=dtype, device='cuda') - B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + A = torch.randn(1, dim+2, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim) + #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim) #print('') #print(A) #print(B.t()) #A[:, :-1] = 0 #B[:, :-1] = 0 + #A.flatten()[:-1] = 0 + #B.flatten()[:-1] = 0 qB, state = F.quantize_nf4(B) F.dequantize_nf4(qB, state) - C3 = torch.matmul(A, B.t()) + #C3 = torch.matmul(A, B.t()) C2 = F.cutlass3_gemm(A, qB.t(), state=state) C1 = bnb.matmul_4bit(A, qB.t(), state) - print(C1) - print(C2) + #print(state) + #print(qB) + + + #print('') + #print(A) + #print(B) + #print('='*89) + #print(C1) + #print(C2) + #print(C3) #print(C1.shape, C2.shape) @@ -2455,7 +2471,7 @@ def test_gemm_4bit(dtype): max_relerr = max(relerr.max(), max_relerr) err = err.mean().item() relerr = relerr.mean().item() - print(err) + #print(err) errs.append(err) relerrs.append(relerr) @@ -2463,20 +2479,20 @@ def test_gemm_4bit(dtype): if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: print('') print(i, err, relerr) - print(A.flatten()[-6:]) - print(B.flatten()[-6:]) - out = A.flatten()[-6:]*B.flatten()[-6:] - print(out) - print(out[:-1].sum()) + #print(A.flatten()[-6:]) + #print(B.flatten()[-6:]) + #out = A.flatten()[-6:]*B.flatten()[-6:] + #print(out) + #print(out[:-1].sum()) print('='*80) - print(C1.flatten()[-6:]) - print(C2.flatten()[-6:]) + #print(C1.flatten()[-6:]) + #print(C2.flatten()[-6:]) #assert False, 'ERROR' c = int(C1.numel()*0.0014*(dim/256))+1 c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) - #print(c/math.sqrt(dim)) + print(c/math.sqrt(dim)) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim)) print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
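
How the new naive kernel arrives at its result: gemm_4bit_inference_naive launches (m+3)/4 blocks of 128 threads, so each of the four warps in a block owns one output row of the packed matrix B. Per loop iteration a lane reads 8 packed bytes (16 NF4 nibbles), looks each nibble up in the nf4_data table, scales the value by the blockwise absmax (indexed by the flat position in the unpacked B divided by blocksize), and accumulates a partial dot product against the activation row A; cub::WarpReduce then sums the 32 lane partials into out[row_B]. Below is a minimal CPU sketch of the same math, not part of the bitsandbytes API; the helper name nf4_dequantize_gemv is illustrative, and it assumes B is packed row-major with ldb == K/2, which is what makes the kernel's absidx = (2*offset_B)/blocksize indexing equal to a flat index into the unpacked B.

import torch

# The 16 NF4 code values, copied from nf4_data[] in csrc/kernels.cu.
NF4_CODE = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0])

def nf4_dequantize_gemv(A, B_packed, absmax, blocksize):
    # A:        [1, K] float activations
    # B_packed: [N, K//2] uint8, two NF4 nibbles per byte (high nibble first)
    # absmax:   [N*K // blocksize] per-block scales over the flattened, unpacked B
    N, K_half = B_packed.shape
    K = 2 * K_half
    hi = (B_packed >> 4).long()              # first value of each nibble pair
    lo = (B_packed & 0x0F).long()            # second value of each nibble pair
    B = torch.empty(N, K, dtype=A.dtype)
    B[:, 0::2] = NF4_CODE[hi].to(A.dtype)
    B[:, 1::2] = NF4_CODE[lo].to(A.dtype)
    # blockwise scaling: element i of the flattened B uses absmax[i // blocksize]
    scale = absmax.repeat_interleave(blocksize)[:N * K].reshape(N, K).to(A.dtype)
    return A @ (B * scale).t()               # [1, K] x [K, N] -> [1, N]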
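
Calling convention on the Python side, mirroring tests/test_functional.py::test_gemm_4bit (requires a CUDA device; the dimension is illustrative):

import math, torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 4096
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
B = torch.randn(4 * dim, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)
qB, state = F.quantize_nf4(B)
# cutlass3_gemm now dispatches to cgemm_4bit_inference_naive when B is uint8
C_naive = F.cutlass3_gemm(A, qB.t(), state=state)
C_ref = bnb.matmul_4bit(A, qB.t(), state)
print((C_naive - C_ref).abs().float().mean())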