From b7f04e2a2064575d0c636a89d98a7075c46151e1 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 30 May 2023 20:07:05 -0700 Subject: [PATCH 01/13] Added lookup table. --- Makefile | 4 ++-- csrc/kernels.cu | 9 +++++++++ tests/test_functional.py | 12 ++++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 5fa1f17..2cbb1b9 100644 --- a/Makefile +++ b/Makefile @@ -47,8 +47,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75 -CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 -CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 +#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 +#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 diff --git a/csrc/kernels.cu b/csrc/kernels.cu index ab12c37..7a752cb 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3297,6 +3297,7 @@ template __global__ void gemm_device(int M, #endif } +__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { @@ -3308,6 +3309,12 @@ template __global__ void kgemm_4bit_inference(int M, i const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; + T quant_map[16]; + + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + T local_A[2]; T local_B[64]; unsigned char local_B_4bit[32]; @@ -3410,6 +3417,8 @@ template __global__ void kgemm_4bit_inference(int M, i { local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); } } diff --git a/tests/test_functional.py b/tests/test_functional.py index cc58324..29b82e6 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2297,7 +2297,8 @@ def test_4bit_compressed_stats(quant_type): @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") -@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +@pytest.mark.parametrize("quant_type", ['nf4']) def test_bench_4bit_dequant(quant_type): blocksize = 256 a = torch.rand(1024*12*4, 1024*12, device='cuda').half() @@ -2311,7 +2312,7 @@ def test_bench_4bit_dequant(quant_type): #print(max_theoretical_s*1e6) b = torch.randn(128, 1024*12, device='cuda').half() - iters = 5 + iters = 100 torch.cuda.synchronize() t0 = time.time() for i in range(iters): @@ -2438,9 +2439,11 @@ def test_gemm_4bit(dtype): C3 = torch.matmul(A, B.t()) C2 = F.cutlass3_gemm(A, qB.t(), state=state) C1 = bnb.matmul_4bit(A, qB.t(), state) - C2 = F.cutlass3_gemm(A, qB.t(), state=state) - print(C1.shape, C2.shape) + print(C1) + print(C2) + + #print(C1.shape, C2.shape) # tensor cores are non-deterministic # so we need to analyze errors around the mean @@ -2452,6 +2455,7 @@ def 
test_gemm_4bit(dtype): max_relerr = max(relerr.max(), max_relerr) err = err.mean().item() relerr = relerr.mean().item() + print(err) errs.append(err) relerrs.append(relerr) From e54d2730fc033489be1ee61dab5ac5e22f798527 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 30 May 2023 20:42:21 -0700 Subject: [PATCH 02/13] Added debugging functions. --- csrc/kernels.cu | 15 +++++++++++++-- tests/test_functional.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 7a752cb..ea0be06 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3297,11 +3297,21 @@ template __global__ void gemm_device(int M, #endif } + +template __device__ void printnonzero(T *A, int num_values) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + printf("%i %f\n", i, (float)A[i]); +} + +template __device__ void printnonzero(float *A, int num_values); +template __device__ void printnonzero(half *A, int num_values); + __device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { -#if __CUDA_ARCH__ >= 750 using namespace nvcuda; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; @@ -3469,9 +3479,10 @@ template __global__ void kgemm_4bit_inference(int M, i if(warp_id == (WARPS-1)) wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + printnonzero(smem_A, 32); + if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; -#endif } //#define ROWS 2 diff --git a/tests/test_functional.py b/tests/test_functional.py index 29b82e6..54ceed5 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2414,7 +2414,7 @@ def test_gemm_4bit(dtype): #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: #for dim in [32]: - for dim in [4096]: + for dim in [32]: errs = [] relerrs = [] max_err = 0 From f89ff93e26d02037db30e88053983d6bb12dd660 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 3 Jul 2023 18:45:38 -0700 Subject: [PATCH 03/13] Initial 4-bit naive batch size 1, 81 vs 185. 
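This adds kgemm_4bit_inference_naive: a warp-per-output-row GEMV for the
batch-size-1 inference path. Each warp owns one row of the packed 4-bit
weight matrix B, each lane dequantizes a chunk of NF4 nibbles through the
16-entry lookup table and scales by the blockwise absmax, multiplies against
the matching slice of A, and the per-lane partial sums are combined with a
cub::WarpReduce before lane 0 writes the output element.

A minimal sketch of that structure (illustrative only, not the exact kernel
added here; it assumes ldb = K/2 packed bytes per row, one absmax value per
`blocksize` elements, and K a multiple of 512 so no tail handling is needed):

    #include <cuda_fp16.h>
    #include <cub/cub.cuh>

    // 16-entry NF4 codebook (same values as nf4_data in csrc/kernels.cu).
    __device__ static const float nf4_lut[16] = {
        -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
        -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
         0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f, 0.33791524171829224f,
         0.44070982933044434f, 0.5626170039176941f, 0.7229568362236023f, 1.0f };

    // out[row] = dot(A[0, :], dequant(B[row, :])) for a [1, K] x [M, K]^T product.
    template <int THREADS>
    __global__ void gemv_4bit_nf4_sketch(int M, int K, const half* __restrict__ A,
                                         const unsigned char* B, const float* absmax,
                                         half* out, int ldb, int blocksize)
    {
        typedef cub::WarpReduce<float> WarpReduce;
        __shared__ typename WarpReduce::TempStorage temp[THREADS / 32];

        const int warp_idx  = threadIdx.x / 32;
        const int warp_lane = threadIdx.x % 32;
        const int row_B     = (THREADS / 32) * blockIdx.x + warp_idx; // one warp per output row
        if (row_B >= M) return;                                       // whole warp exits together

        float acc = 0.0f;
        // Each lane handles 16 values (8 packed bytes) per step; the warp strides by 32*16.
        for (int k0 = warp_lane * 16; k0 < K; k0 += 32 * 16)
        {
            const int   byte_off = ldb * row_B + k0 / 2;
            const float scale    = absmax[(row_B * K + k0) / blocksize]; // blockwise absmax

            #pragma unroll
            for (int j = 0; j < 8; j++)
            {
                const unsigned char packed = B[byte_off + j];
                acc += __half2float(A[k0 + 2 * j])     * nf4_lut[packed >> 4]   * scale;
                acc += __half2float(A[k0 + 2 * j + 1]) * nf4_lut[packed & 0x0F] * scale;
            }
        }

        acc = WarpReduce(temp[warp_idx]).Sum(acc);
        if (warp_lane == 0)
            out[row_B] = __float2half(acc);
    }

With THREADS = 128 this corresponds to the launch added in csrc/ops.cu below:
(m + 3) / 4 blocks of 128 threads, i.e. four output rows per block.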
--- bitsandbytes/functional.py | 2 +- csrc/kernels.cu | 162 +++++++++++++++++++++++++++++++++---- csrc/kernels.cuh | 2 + csrc/ops.cu | 24 +++++- csrc/ops.cuh | 1 + csrc/pythonInterface.c | 6 ++ tests/test_functional.py | 108 ++++++++++++++----------- 7 files changed, 240 insertions(+), 65 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index afa346e..3ae4237 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1503,7 +1503,7 @@ def cutlass3_gemm( ldc = ct.c_int32(ldc) if B.dtype == torch.uint8: - lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + lib.cgemm_4bit_inference_naive(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) elif A.dtype == torch.float32: lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) elif A.dtype == torch.float16: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index ea0be06..216d436 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3088,7 +3088,7 @@ template __device__ inline void vector_l } } -#define WARPS 5 +#define WARPS 3 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3298,15 +3298,15 @@ template __global__ void gemm_device(int M, } -template __device__ void printnonzero(T *A, int num_values) +template __device__ void printnonzero(T *A, int num_values, const char * strval) { for(int i = 0; i < num_values; i++) if((float)A[i] != 0.0) - printf("%i %f\n", i, (float)A[i]); + printf("%s %i %f\n", strval, i, (float)A[i]); } -template __device__ void printnonzero(float *A, int num_values); -template __device__ void printnonzero(half *A, int num_values); +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template __device__ void printnonzero(half *A, int num_values, const char*strval); __device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) @@ -3315,6 +3315,7 @@ template __global__ void kgemm_4bit_inference(int M, i using namespace nvcuda; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; const int half_warp_id = threadIdx.x / 16; const int half_warp_lane = threadIdx.x % 16; const int batch_size_warps = (WARPS-1)*2; @@ -3324,23 +3325,30 @@ template __global__ void kgemm_4bit_inference(int M, i #pragma unroll 16 for(int i = 0; i < 16; i++) quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; T local_A[2]; T local_B[64]; unsigned char local_B_4bit[32]; + const int a_tile_offset = 16; const int b_tile_offset = (16*32 + 16); - __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - //__shared__ T smem_C[8*32]; + __shared__ T smem_C[8*32]; wmma::fragment a_frag; wmma::fragment b_frag; wmma::fragment c_frag; wmma::fill_fragment(c_frag, 0.0f); + for(int i = threadIdx.x; i < (8*32); 
i+=blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + int ticktock = 0; int idx = 0 + threadIdx.x; int loaded_values = 0; @@ -3366,8 +3374,17 @@ template __global__ void kgemm_4bit_inference(int M, i #pragma unroll 64 for(int col = 0; col < 64; col+=2) { - local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); - local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); } } @@ -3391,13 +3408,17 @@ template __global__ void kgemm_4bit_inference(int M, i smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; } ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) { idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); - __syncthreads(); + //__syncthreads(); if(idx < K && warp_id < (WARPS-1)) { if(loaded_values == 0) @@ -3425,11 +3446,17 @@ template __global__ void kgemm_4bit_inference(int M, i #pragma unroll 64 for(int col = 0; col < 64; col+=2) { - local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); - local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); - //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); - //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); } + //printnonzero(local_B, 128, ""); } smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; @@ -3463,6 +3490,11 @@ template __global__ void kgemm_4bit_inference(int M, i } __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} if(warp_id != (WARPS-1)){ return; } // only warp_id == (WARPS-1) from here int warp_lane = threadIdx.x % 32; @@ -3470,6 +3502,8 @@ template __global__ void kgemm_4bit_inference(int M, i ticktock = ticktock == 0 ? 
1 : 0; for(int k = 0; k < batch_size_warps; k++) { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); @@ -3477,14 +3511,101 @@ template __global__ void kgemm_4bit_inference(int M, i // 129 mu if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); - printnonzero(smem_A, 32); + //printnonzero(smem_C, 32, ""); if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_A[warp_lane]; + out[col_offset + warp_lane] = smem_C[warp_lane]; } +#define num_values_4bit 16 +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + // per threadblock: + // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps] + // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1 + // 4 warps -> 4 loads per iter + // 1x128 * 128x4 -> 1x4 outputs + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[4]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = 4*blockIdx.x + warp_idx; + T local_C = T(0); + + T quant_map[16]; + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + + unsigned char local_B_4bit[num_values_4bit/2]; + T local_B[num_values_4bit]; + + // need to increase occupancy by splitting the rows, but can be done later + + // A: [1, K] + // B: [N, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) + { + int offset_B = ldb*row_B + (inner_idx/2); + int absidx = (2*offset_B)/blocksize; + T local_absmax = __ldg(&(absmax[absidx])); + + //printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B); + #pragma unroll + for(int k = 0; k < num_values_4bit/2; k++) + { + if((inner_idx/2) < K && row_B < M) + local_B_4bit[k] = B[offset_B + k]; + else + local_B_4bit[k] = 0b01110111; + } + + + //if(row_B < M) + //{ + // if((inner_idx/num_values_4bit) < K) + // reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[offset_B/(num_values_4bit/2)]; + // else + // { + // for(int k = 0; k < num_values_4bit/2; k++) + // { + // if((inner_idx/2) < K && row_B < M) + // local_B_4bit[k] = B[offset_B + k]; + // else + // local_B_4bit[k] = 0b01110111; + // } + // } + //} + + + + #pragma unroll + for(int k = 0; k < num_values_4bit; k++) + { + local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; + local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; + } + + //printnonzero(local_B, 4, "B values: "); + + #pragma unroll + for(int k = 0; k < num_values_4bit; k++) + local_C += A[inner_idx + k]*local_B[k]; + + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if(row_B < M && warp_lane == 0) + out[row_B] = local_C; + +} + + //#define ROWS 2 //template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) //{ @@ -3647,8 +3768,15 @@ template __global__ void gemm_device(int M, int N, int K, half * _ template __global__ void 
gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 30faf4a..05d0715 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -106,6 +106,7 @@ template __global__ voi template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); + __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); @@ -124,6 +125,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kfunc(T *A, T *B, T value, long n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 9c042fa..ed242c9 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -723,7 +723,28 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //cout << m << endl; //cout << n << endl; //cout << k << endl; - kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + kgemm_4bit_inference<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + 
//kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+3)/4; + + cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } @@ -747,6 +768,7 @@ template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 5b9a32b..e4df195 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -200,6 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); template void func(T *A, T *B, T value, long n); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 23a0364..d42f17f 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -28,6 +28,9 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ @@ -345,6 +348,9 @@ extern "C" void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float 
*absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + void *cget_managed_ptr(size_t bytes) { void *ptr; diff --git a/tests/test_functional.py b/tests/test_functional.py index 54ceed5..752dd1d 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1773,17 +1773,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 1 -seqdim = 1 +batch_size = 32 +seqdim = 512+256 values = [] #values.append((batch_size, seqdim, 768, 4 * 768)) #values.append((batch_size, seqdim, 1024, 4*1024)) #values.append((batch_size, seqdim, 1536, 4*1536)) #values.append((batch_size, seqdim, 2048, 4*2048)) #values.append((batch_size, seqdim, 2560, 4*2560)) -values.append((batch_size, seqdim, 4096, 4*4096)) -values.append((batch_size, seqdim, 5120, 4*5120)) -values.append((batch_size, seqdim, 6656, 4*6656)) +#values.append((batch_size, seqdim, 4096, 4*4096)) +#values.append((batch_size, seqdim, 5120, 4*5120)) +#values.append((batch_size, seqdim, 6656, 4*6656)) values.append((batch_size, seqdim, 8192, 4*8192)) #values.append((batch_size, seqdim, 5140, 4*5140)) #values.append((batch_size, seqdim, 12288, 4*12288)) @@ -1827,19 +1827,19 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) - torch.cuda.synchronize() - print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) + #torch.cuda.synchronize() + #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) - torch.cuda.synchronize() - print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) + #torch.cuda.synchronize() + #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() @@ -1901,21 +1901,21 @@ def test_bench_matmul(batch, seq, model, hidden): #torch.cuda.synchronize() #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - linear8bit(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linear8bit(A) - torch.cuda.synchronize() - print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #linear8bit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (eval): 
[{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - linearMixedBit(A) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - linearMixedBit(A) - torch.cuda.synchronize() - print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #linearMixedBit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linearMixedBit(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") #linear8bit_train(A) #torch.cuda.synchronize() @@ -2411,10 +2411,14 @@ def test_cutlass3_gemm(dtype): #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) def test_gemm_4bit(dtype): + print('') #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: #for dim in [32]: - for dim in [32]: + for dim in [4096]: + #for dim in [5120]: + #for dim in [6656]: + #for dim in [128]: errs = [] relerrs = [] max_err = 0 @@ -2424,24 +2428,36 @@ def test_gemm_4bit(dtype): #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') #A = torch.rand(1, 4096, dtype=dtype, device='cuda') #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') - A = torch.randn(1, dim+0, dtype=dtype, device='cuda') - B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + A = torch.randn(1, dim+2, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim) + #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim) #print('') #print(A) #print(B.t()) #A[:, :-1] = 0 #B[:, :-1] = 0 + #A.flatten()[:-1] = 0 + #B.flatten()[:-1] = 0 qB, state = F.quantize_nf4(B) F.dequantize_nf4(qB, state) - C3 = torch.matmul(A, B.t()) + #C3 = torch.matmul(A, B.t()) C2 = F.cutlass3_gemm(A, qB.t(), state=state) C1 = bnb.matmul_4bit(A, qB.t(), state) - print(C1) - print(C2) + #print(state) + #print(qB) + + + #print('') + #print(A) + #print(B) + #print('='*89) + #print(C1) + #print(C2) + #print(C3) #print(C1.shape, C2.shape) @@ -2455,7 +2471,7 @@ def test_gemm_4bit(dtype): max_relerr = max(relerr.max(), max_relerr) err = err.mean().item() relerr = relerr.mean().item() - print(err) + #print(err) errs.append(err) relerrs.append(relerr) @@ -2463,20 +2479,20 @@ def test_gemm_4bit(dtype): if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: print('') print(i, err, relerr) - print(A.flatten()[-6:]) - print(B.flatten()[-6:]) - out = A.flatten()[-6:]*B.flatten()[-6:] - print(out) - print(out[:-1].sum()) + #print(A.flatten()[-6:]) + #print(B.flatten()[-6:]) + #out = A.flatten()[-6:]*B.flatten()[-6:] + #print(out) + #print(out[:-1].sum()) print('='*80) - print(C1.flatten()[-6:]) - print(C2.flatten()[-6:]) + #print(C1.flatten()[-6:]) + #print(C2.flatten()[-6:]) #assert False, 'ERROR' c = int(C1.numel()*0.0014*(dim/256))+1 c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) - #print(c/math.sqrt(dim)) + print(c/math.sqrt(dim)) print('') print(dim, sum(errs)/len(errs)/math.sqrt(dim)) print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) From dfe6900b94a0b38c649ea39b2dd12392c835195f Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 4 Jul 2023 15:20:10 -0700 Subject: [PATCH 04/13] Vectorized loads, conflict free NF4; 52 vs 172. 
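Two changes to the naive kernel. First, each lane now processes 32 values per
step (num_values_4bit 16 -> 32), so the 16 packed bytes of B and the 32 half
values of A it needs can be fetched with 128-bit vectorized loads
(reinterpret_cast to int4) instead of byte- and element-wise loads. Second,
the NF4 codebook moves from per-thread registers into shared memory, stored
once per thread with a stride of THREADS: thread t reads entry v at
quant_map[v*THREADS + t], so the 32 lanes of a warp spread their lookups
across different shared-memory locations instead of all serializing on the
same 16 entries.

A small self-contained dequantization sketch of just these two ideas
(illustrative; names are not from this patch, and it assumes B is 16-byte
aligned, as cudaMalloc allocations are, with one absmax value per `blocksize`
elements):

    #include <cuda_fp16.h>

    __device__ static const float nf4_codebook[16] = {
        -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
        -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
         0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f, 0.33791524171829224f,
         0.44070982933044434f, 0.5626170039176941f, 0.7229568362236023f, 1.0f };

    // Each thread dequantizes 16 packed bytes (32 NF4 values) of B into out.
    template <int THREADS>
    __global__ void nf4_dequant_sketch(const unsigned char* B, const float* absmax,
                                       half* out, int n_packed_bytes, int blocksize)
    {
        // One strided copy of the codebook per thread: lane t reads index
        // v*THREADS + t, which spreads the lookups across banks rather than
        // having the whole warp hammer the same 16 entries.
        __shared__ half quant_map[16 * THREADS];
        for (int v = 0; v < 16; v++)
            quant_map[v * THREADS + threadIdx.x] = __float2half(nf4_codebook[v]);
        __syncthreads();

        const int byte_idx = 16 * (blockIdx.x * THREADS + threadIdx.x);
        if (byte_idx + 16 > n_packed_bytes) return;

        // One 128-bit transaction instead of sixteen 8-bit loads.
        __align__(16) unsigned char local_B_4bit[16];
        reinterpret_cast<int4*>(local_B_4bit)[0] =
            reinterpret_cast<const int4*>(B)[byte_idx / 16];

        const half scale = __float2half(absmax[(2 * byte_idx) / blocksize]);

        #pragma unroll
        for (int k = 0; k < 16; k++)
        {
            out[2 * byte_idx + 2 * k]     = __hmul(quant_map[(local_B_4bit[k] >> 4)   * THREADS + threadIdx.x], scale);
            out[2 * byte_idx + 2 * k + 1] = __hmul(quant_map[(local_B_4bit[k] & 0x0F) * THREADS + threadIdx.x], scale);
        }
    }

In the kernel below the same pattern is applied inside the GEMV inner loop,
with A additionally loaded as four int4 (32 half) chunks per step.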
--- csrc/kernels.cu | 79 +++++++++++++++++++--------------------- csrc/ops.cu | 1 + tests/test_functional.py | 10 ++--- 3 files changed, 44 insertions(+), 46 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 216d436..34e552b 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3519,7 +3519,7 @@ template __global__ void kgemm_4bit_inference(int M, i out[col_offset + warp_lane] = smem_C[warp_lane]; } -#define num_values_4bit 16 +#define num_values_4bit 32 template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { @@ -3529,72 +3529,68 @@ template __global__ void kgemm_4bit_inference_naive(in // 4 warps -> 4 loads per iter // 1x128 * 128x4 -> 1x4 outputs typedef cub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[4]; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; const int warp_idx = threadIdx.x / 32; const int warp_lane = threadIdx.x % 32; - const int row_B = 4*blockIdx.x + warp_idx; + const int row_B = (THREADS/32)*blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit/2; T local_C = T(0); - T quant_map[16]; - #pragma unroll 16 - for(int i = 0; i < 16; i++) - quant_map[i] = nf4_data[i]; - - unsigned char local_B_4bit[num_values_4bit/2]; + unsigned char local_B_4bit[num_values_8bit]; T local_B[num_values_4bit]; + T local_A[num_values_4bit]; + __shared__ half quant_map[16*THREADS]; - // need to increase occupancy by splitting the rows, but can be done later + for(int i = 0; i < 16; i++) + quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i]; + __syncthreads(); // A: [1, K] // B: [N, K] for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) { - int offset_B = ldb*row_B + (inner_idx/2); - int absidx = (2*offset_B)/blocksize; + int inner_idx_halved = inner_idx/2; + int offset_B = ldb*row_B; + int absidx = ((2*offset_B)+inner_idx)/blocksize; T local_absmax = __ldg(&(absmax[absidx])); - //printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B); - #pragma unroll - for(int k = 0; k < num_values_4bit/2; k++) + if(row_B < M) { - if((inner_idx/2) < K && row_B < M) - local_B_4bit[k] = B[offset_B + k]; + if((inner_idx_halved + num_values_8bit) < K) + { + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; + } else - local_B_4bit[k] = 0b01110111; + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + if((inner_idx/2) + j < K) + local_B_4bit[j] = 0b01110111; + } } - - //if(row_B < M) - //{ - // if((inner_idx/num_values_4bit) < K) - // reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[offset_B/(num_values_4bit/2)]; - // else - // { - // for(int k = 0; k < num_values_4bit/2; k++) - // { - // if((inner_idx/2) < K && row_B < M) - // local_B_4bit[k] = B[offset_B + k]; - // else - // local_B_4bit[k] = 0b01110111; - // } - // } - //} - - - #pragma unroll for(int k = 0; k < num_values_4bit; k++) { - local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; - local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; + local_B[k*2] = quant_map[(local_B_4bit[k] >> 4)*THREADS+threadIdx.x]*local_absmax; + local_B[k*2+ 1] = quant_map[(local_B_4bit[k] & 0x0F)*THREADS+threadIdx.x]*local_absmax; } - //printnonzero(local_B, 4, "B values: "); + if(inner_idx+num_values_4bit) + { + reinterpret_cast(local_A)[0] = 
reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 1]; + reinterpret_cast(local_A)[2] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 2]; + reinterpret_cast(local_A)[3] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 3]; + } + else + for(int k = 0; k < num_values_4bit; k++) + local_A[k] = A[inner_idx +k]; #pragma unroll for(int k = 0; k < num_values_4bit; k++) - local_C += A[inner_idx + k]*local_B[k]; + local_C += local_A[k]*local_B[k]; } @@ -3773,6 +3769,7 @@ template __global__ void kgemm_4bit_inference(int M, int N, int K, ha template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/ops.cu b/csrc/ops.cu index ed242c9..c30e979 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -733,6 +733,7 @@ template void gemm_4bit_inference_naive(int m, int n, int k, T * A, { int num_blocks = (m+3)/4; + //int num_blocks = m; cout << num_blocks << endl; //cout << lda << endl; diff --git a/tests/test_functional.py b/tests/test_functional.py index 752dd1d..598b995 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2415,21 +2415,21 @@ def test_gemm_4bit(dtype): #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: #for dim in [32]: - for dim in [4096]: + for dim in [2*4096]: #for dim in [5120]: #for dim in [6656]: - #for dim in [128]: + #for dim in [4]: errs = [] relerrs = [] max_err = 0 max_relerr = 0 - for i in range(1): + for i in range(100): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') #A = torch.rand(1, 4096, dtype=dtype, device='cuda') #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') - A = torch.randn(1, dim+2, dtype=dtype, device='cuda') - B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim) + A = torch.randn(1, dim, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim) #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim) #print('') From 02fd80cb814285984415fd903278b8217c18c4df Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 4 Jul 2023 19:58:31 -0700 Subject: [PATCH 05/13] Added bfloat16 quantizations and tests. 
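This extends the blockwise quantize/dequantize kernels and the naive 4-bit
GEMV to bfloat16. Three pieces fit together: the templated CUDA kernels gain
explicit __nv_bfloat16 instantiations next to the existing half/float ones;
csrc/pythonInterface.c exposes them through dtype-suffixed C entry points
(cgemm_4bit_inference_naive_fp16 / _bf16, plus the bf16 blockwise
quantize/dequantize wrappers); and bitsandbytes/functional.py dispatches on
the tensor dtype. The blockwise quant_state now also stores A.dtype (it is
[qabsmax, code, blocksize, nested, dtype, offset, state2]), so
dequantize_blockwise can allocate its output in the original dtype instead of
assuming float32. matmul_4bit additionally routes the single-token, no-grad
case (A.numel() == A.shape[-1] and not A.requires_grad) through
F.cutlass3_gemm, i.e. through the naive inference kernel.

A skeleton of the instantiation-plus-wrapper pattern repeated for each kernel
(the launch body is elided here; the real one lives in csrc/ops.cu):

    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // Templated host launcher; in csrc/ops.cu this launches
    // kgemm_4bit_inference_naive<T, 128> on (m + 3) / 4 blocks of 128 threads.
    template <typename T>
    void gemm_4bit_inference_naive(int m, int n, int k, T* A, unsigned char* B,
                                   float* absmax, T* out, int lda, int ldb,
                                   int ldc, int blocksize)
    {
        // kernel launch elided in this sketch
    }

    // Thin C wrappers that ctypes resolves by name; functional.py picks one
    // based on A.dtype (torch.float16 -> _fp16, torch.bfloat16 -> _bf16).
    extern "C" {
    void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half* A,
        unsigned char* B, float* absmax, half* out,
        int lda, int ldb, int ldc, int blocksize)
    { gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

    void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16* A,
        unsigned char* B, float* absmax, __nv_bfloat16* out,
        int lda, int ldb, int ldc, int blocksize)
    { gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
    }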
--- bitsandbytes/autograd/_functions.py | 5 +- bitsandbytes/functional.py | 46 ++++++++++------ csrc/kernels.cu | 65 +++++++++++++++------- csrc/ops.cu | 33 +++++------ csrc/pythonInterface.c | 62 +++++++++++++++------ tests/test_functional.py | 85 +++++++++++++---------------- 6 files changed, 175 insertions(+), 121 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 63b7156..eeef93b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -561,4 +561,7 @@ def matmul( def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): assert quant_state is not None - return MatMul4Bit.apply(A, B, out, bias, quant_state) + if A.numel() == A.shape[-1] and A.requires_grad == False: + return F.cutlass3_gemm(A, B.t(), out, state=quant_state) + else: + return MatMul4Bit.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3ae4237..95a15d5 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -617,6 +617,8 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.float16: lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -629,11 +631,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou offset = absmax.mean() absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) - state = [qabsmax, code, blocksize, nested, offset, state2] + state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2] else: - state = [absmax, code, blocksize, nested, None, None] - - + state = [absmax, code, blocksize, nested, A.dtype, None, None] return out, state @@ -678,18 +678,16 @@ def dequantize_blockwise( name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - if out is None: - out = torch.zeros_like(A, dtype=torch.float32) - if quant_state is None: - quant_state = (absmax, code, blocksize) - assert absmax is not None and out is not None - else: - absmax, code, blocksize, nested, offset, state2 = quant_state - if nested: - absmax = dequantize_blockwise(absmax, state2) - absmax += offset + quant_state = (absmax, code, blocksize, False, torch.float32, None, None) + absmax, code, blocksize, nested, dtype, offset, state2 = quant_state + if nested: + absmax = dequantize_blockwise(absmax, state2) + absmax += offset + + if out is None: + out = torch.empty(A.shape, dtype=dtype, device=A.device) if A.device.type != 'cpu': device = pre_call(A.device) @@ -701,6 +699,8 @@ def dequantize_blockwise( lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + elif out.dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), 
ct.c_int(blocksize), ct.c_int(A.numel())) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -774,6 +774,11 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.bfloat16: + if quant_type == 'fp4': + lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -860,6 +865,11 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) else: lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + elif out.dtype == torch.bfloat16: + if quant_type == 'fp4': + lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -1503,7 +1513,12 @@ def cutlass3_gemm( ldc = ct.c_int32(ldc) if B.dtype == torch.uint8: - lib.cgemm_4bit_inference_naive(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + else: + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') elif A.dtype == torch.float32: lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) elif A.dtype == torch.float16: @@ -1515,7 +1530,6 @@ def cutlass3_gemm( - def igemm( A: Tensor, B: Tensor, diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 34e552b..d443192 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3540,7 +3540,7 @@ template __global__ void kgemm_4bit_inference_naive(in unsigned char local_B_4bit[num_values_8bit]; T local_B[num_values_4bit]; T local_A[num_values_4bit]; - __shared__ half quant_map[16*THREADS]; + __shared__ T quant_map[16*THREADS]; for(int i = 0; i < 16; i++) quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i]; @@ -3769,11 +3769,8 @@ template __global__ void kgemm_4bit_inference(int M, int N, int K, ha template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int 
ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); @@ -3929,6 +3926,20 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) @@ -3937,13 +3948,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) @@ -3951,13 +3955,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) 
-MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) -MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) @@ -3966,12 +3963,38 @@ MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4) + template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int 
blocksize, const int n); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ diff --git a/csrc/ops.cu b/csrc/ops.cu index c30e979..902129f 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -733,20 +733,8 @@ template void gemm_4bit_inference_naive(int m, int n, int k, T * A, { int num_blocks = (m+3)/4; - //int num_blocks = m; - cout << num_blocks << endl; - //cout << lda << endl; - //cout << ldb << endl; - //cout << ldc << endl; - - //cout << m << endl; - //cout << n << endl; - //cout << k << endl; kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); - //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } template void func(T *A, T *B, T value, long n) @@ -770,6 +758,7 @@ template void func(float *A, float *B, float value, long n); template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); @@ -796,19 +785,27 @@ template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char 
*out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index d42f17f..4cbabae 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -28,9 +28,12 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) +void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { 
gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ @@ -106,19 +109,29 @@ void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } + +void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); } + +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ 
dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ + +void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); } + #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -177,21 +190,31 @@ extern "C" void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } - void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } - void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } - - void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } - void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } - void cquantize_blockwise_fp16_nf4(float * code, half *A, float 
*absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } - void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + + void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + + void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + + void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } + + void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, 
blocksize, n); } + #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ @@ -348,9 +371,6 @@ extern "C" void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } - void cgemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } - void *cget_managed_ptr(size_t bytes) { void *ptr; @@ -374,6 +394,12 @@ extern "C" CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + + void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 598b995..cef741e 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -154,34 +154,36 @@ def test_dynamic_quantization(): +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) @pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) -def test_dynamic_blockwise_quantization(nested, blocksize): +def test_dynamic_blockwise_quantization(dtype, nested, blocksize): #print('') diffs = [] reldiffs = [] for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") + A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) + diff = torch.abs(A1 - A2).float() + reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) abserr = sum(diffs)/len(diffs) relerr = sum(reldiffs)/len(reldiffs) + #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) + #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) assert abserr < 0.011 assert relerr < 0.018 - #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) - #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) + assert A2.dtype == dtype diffs = [] for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") + A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype) C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) A2 = 
F.dequantize_blockwise(C, S) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) + diff = torch.abs(A1 - A2).float() + reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) @@ -189,6 +191,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.0035 assert relerr < 0.015 + assert A2.dtype == dtype #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) @@ -1773,8 +1776,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 32 -seqdim = 512+256 +batch_size = 1 +seqdim = 1 values = [] #values.append((batch_size, seqdim, 768, 4 * 768)) #values.append((batch_size, seqdim, 1024, 4*1024)) @@ -1800,7 +1803,7 @@ def test_bench_matmul(batch, seq, model, hidden): B_fp4, state = F.quantize_fp4(B) B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) - B_nf4, state_nf4= F.quantize_nf4(B) + B_nf4, state_nf4 = F.quantize_nf4(B) linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half() linear8bit.eval() @@ -1813,6 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden): linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4) # warmup for i in range(iters): @@ -1844,7 +1848,8 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() t0 = time.time() for i in range(iters): - bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) + #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) + F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4) torch.cuda.synchronize() print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) @@ -2221,7 +2226,8 @@ def test_bench_dequantization(): -def test_fp4_quant(): +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) +def test_fp4_quant(dtype): vals = list(product([0, 1], repeat=4)) code = {} @@ -2243,7 +2249,7 @@ def test_fp4_quant(): result = sign*exp*frac code[idx] = result - A1 = torch.randn(1024, 1024, device='cuda').half() + A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype) qa, SA = F.quantize_fp4(A1, blocksize=64) A2 = F.dequantize_fp4(qa, SA) @@ -2252,7 +2258,7 @@ def test_fp4_quant(): idx = err > 1.0 err = err.mean() - + assert A2.dtype == dtype assert err.item() < 0.1 assert relerr.item() < 0.28 @@ -2409,20 +2415,16 @@ def test_cutlass3_gemm(dtype): print(dim, (max_err.item(), max_relerr.item())) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) -@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16']) +#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16']) def test_gemm_4bit(dtype): print('') - #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: - #for dim in [4096, 5120, 6656, 8192]: - #for dim in [32]: - for dim in [2*4096]: - #for dim in [5120]: - #for dim in [6656]: - #for dim in [4]: + for dim in [64, 128, 256, 512, 1024, 2048, 4096]: errs = [] relerrs = [] max_err = 0 max_relerr = 0 + for i in range(100): #A = torch.rand(2, 4092, dtype=dtype, 
device='cuda') #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') @@ -2443,14 +2445,13 @@ def test_gemm_4bit(dtype): qB, state = F.quantize_nf4(B) F.dequantize_nf4(qB, state) - #C3 = torch.matmul(A, B.t()) + #C2 = bnb.matmul_4bit(A, qB.t(), state) C2 = F.cutlass3_gemm(A, qB.t(), state=state) - C1 = bnb.matmul_4bit(A, qB.t(), state) + C1 = torch.matmul(A, B.t()) #print(state) #print(qB) - #print('') #print(A) #print(B) @@ -2464,8 +2465,8 @@ def test_gemm_4bit(dtype): # tensor cores are non-deterministic # so we need to analyze errors around the mean # to test our implementation - err = torch.abs(C1-C2) - mag = torch.abs(C1)+1e-8 + err = torch.abs(C1-C2).float() + mag = torch.abs(C1).float()+1e-5 relerr = err/mag max_err = max(err.max(), max_err) max_relerr = max(relerr.max(), max_relerr) @@ -2476,27 +2477,17 @@ def test_gemm_4bit(dtype): errs.append(err) relerrs.append(relerr) - if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: - print('') - print(i, err, relerr) - #print(A.flatten()[-6:]) - #print(B.flatten()[-6:]) - #out = A.flatten()[-6:]*B.flatten()[-6:] - #print(out) - #print(out[:-1].sum()) - print('='*80) - #print(C1.flatten()[-6:]) - #print(C2.flatten()[-6:]) - #assert False, 'ERROR' - c = int(C1.numel()*0.0014*(dim/256))+1 c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) - print(c/math.sqrt(dim)) - print('') - print(dim, sum(errs)/len(errs)/math.sqrt(dim)) - print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) - print(dim, (max_err.item(), max_relerr.item())) + #print('') + #print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + #print(dim, (max_err.item(), max_relerr.item())) + #print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) + #print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) + assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011 + assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15 @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): From 7e49b5b9384042a5b4eec5a69abf45cfe0c3b8da Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 8 Jul 2023 14:27:12 -0700 Subject: [PATCH 06/13] Added warp_shuffle indexing 185 vs 54. 
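
Context for this change: instead of reading the 16-entry NF4 table out of shared memory for every nibble, the fast path keeps one table entry per lane in a register (lane i holds nf4_data[i % 16]) and uses __shfl_sync as a register-file lookup; a small shared-memory copy of the table is kept only for the tail case where part of the warp has already exited. Below is a minimal sketch of that lookup, not the kernel itself: the helper name and signature are illustrative, and the values are shown as float for simplicity where the kernel shuffles the storage type T.

    // Sketch only. Assumes all 32 lanes of the warp are active (the patch
    // guards the tail case separately) and that lane_quant_value was set once
    // per thread as lane_quant_value = nf4_data[warp_lane % 16].
    __device__ void dequant_pair_via_shuffle(unsigned char packed, float absmax,
                                             float lane_quant_value,
                                             float &hi, float &lo)
    {
        // __shfl_sync broadcasts the register held by lane (packed >> 4) resp.
        // (packed & 0x0F); since lane i holds table entry i % 16, this acts as
        // a 16-entry lookup done entirely in registers, with no shared-memory
        // traffic or bank conflicts.
        hi = __shfl_sync(0xffffffffu, lane_quant_value, packed >> 4)   * absmax;
        lo = __shfl_sync(0xffffffffu, lane_quant_value, packed & 0x0F) * absmax;
    }
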
--- csrc/kernels.cu | 32 ++++++++++++++++++++++++++------ tests/test_functional.py | 7 ++++--- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d443192..0a845d4 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3537,14 +3537,20 @@ template __global__ void kgemm_4bit_inference_naive(in const int num_values_8bit = num_values_4bit/2; T local_C = T(0); + T lane_quant_value = nf4_data[warp_lane % 16]; + unsigned char local_B_4bit[num_values_8bit]; T local_B[num_values_4bit]; T local_A[num_values_4bit]; __shared__ T quant_map[16*THREADS]; + __shared__ T quant_map2[16]; + + //for(int i = 0; i < 16; i++) + // quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i]; + //__syncthreads(); for(int i = 0; i < 16; i++) - quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i]; - __syncthreads(); + quant_map2[i] = nf4_data[i]; // A: [1, K] // B: [N, K] @@ -3570,11 +3576,25 @@ template __global__ void kgemm_4bit_inference_naive(in } } - #pragma unroll - for(int k = 0; k < num_values_4bit; k++) + if(inner_idx+(num_values_4bit*32) < K) { - local_B[k*2] = quant_map[(local_B_4bit[k] >> 4)*THREADS+threadIdx.x]*local_absmax; - local_B[k*2+ 1] = quant_map[(local_B_4bit[k] & 0x0F)*THREADS+threadIdx.x]*local_absmax; + // full warp is running + #pragma unroll + for(int k = 0; k < num_values_4bit; k++) + { + local_B[k*2] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] >> 4)*local_absmax; + local_B[k*2 + 1] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] & 0x0F)*local_absmax; + } + } + else + { + // part of the warp exited already + #pragma unroll + for(int k = 0; k < num_values_4bit; k++) + { + local_B[k*2] = quant_map2[(local_B_4bit[k] >> 4)]*local_absmax; + local_B[k*2 + 1] = quant_map2[(local_B_4bit[k] & 0x0F)]*local_absmax; + } } if(inner_idx+num_values_4bit) diff --git a/tests/test_functional.py b/tests/test_functional.py index cef741e..70d5515 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2419,7 +2419,8 @@ def test_cutlass3_gemm(dtype): #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16']) def test_gemm_4bit(dtype): print('') - for dim in [64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [64, 128, 256, 512, 1024, 2048, 4096]: + for dim in [4096]: errs = [] relerrs = [] max_err = 0 @@ -2486,8 +2487,8 @@ def test_gemm_4bit(dtype): #print(dim, (max_err.item(), max_relerr.item())) #print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) #print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) - assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011 - assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15 + #assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011 + #assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15 @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): From eefbf60270497d0dd55b7abe18c519f0c75331f3 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 8 Jul 2023 16:31:58 -0700 Subject: [PATCH 07/13] Turning optimization (float accumulation). 185 vs 50. 
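
Context for this change: the per-thread partial dot product is now accumulated in a 32-bit float (and cub::WarpReduce is instantiated for float) instead of the 16-bit storage type, with the result cast back to T only when it is written out. A minimal standalone sketch of that accumulation pattern follows; the helper name is illustrative, T stands for half or __nv_bfloat16, and cuda_fp16.h / cuda_bf16.h are assumed to be included.

    // Multiply in the storage type, accumulate in fp32 - this mirrors
    // local_C += (float)(local_A[k]*local_B[k]) in the kernel and avoids the
    // rounding error of repeated 16-bit additions during the reduction.
    template <typename T, int N>
    __device__ float dot_fp32_accum(const T *a, const T *b)
    {
        float acc = 0.0f;
        #pragma unroll
        for (int k = 0; k < N; k++)
            acc += (float)(a[k] * b[k]);
        return acc;  // caller reduces across the warp (cub::WarpReduce<float>)
                     // and casts to T for the final store, as the kernel does.
    }
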
--- csrc/kernels.cu | 50 ++++++++++++++-------------------------- tests/test_functional.py | 10 ++++---- 2 files changed, 22 insertions(+), 38 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 0a845d4..dd1f6f2 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3528,29 +3528,26 @@ template __global__ void kgemm_4bit_inference_naive(in // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1 // 4 warps -> 4 loads per iter // 1x128 * 128x4 -> 1x4 outputs - typedef cub::WarpReduce WarpReduce; + //typedef cub::WarpReduce WarpReduce; + typedef cub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; const int warp_idx = threadIdx.x / 32; const int warp_lane = threadIdx.x % 32; const int row_B = (THREADS/32)*blockIdx.x + warp_idx; const int num_values_8bit = num_values_4bit/2; - T local_C = T(0); - - T lane_quant_value = nf4_data[warp_lane % 16]; + //T local_C = T(0.0f); + float local_C = 0.0f; unsigned char local_B_4bit[num_values_8bit]; T local_B[num_values_4bit]; T local_A[num_values_4bit]; - __shared__ T quant_map[16*THREADS]; - __shared__ T quant_map2[16]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); - //for(int i = 0; i < 16; i++) - // quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i]; - //__syncthreads(); - - for(int i = 0; i < 16; i++) - quant_map2[i] = nf4_data[i]; + for(int i = threadIdx.x; i < 16; i++) + quant_map[i] = nf4_data[i]; + __syncthreads(); // A: [1, K] // B: [N, K] @@ -3559,7 +3556,7 @@ template __global__ void kgemm_4bit_inference_naive(in int inner_idx_halved = inner_idx/2; int offset_B = ldb*row_B; int absidx = ((2*offset_B)+inner_idx)/blocksize; - T local_absmax = __ldg(&(absmax[absidx])); + local_absmax = __ldg(&(absmax[absidx])); if(row_B < M) { @@ -3576,25 +3573,11 @@ template __global__ void kgemm_4bit_inference_naive(in } } - if(inner_idx+(num_values_4bit*32) < K) + #pragma unroll + for(int k = 0; k < num_values_4bit; k++) { - // full warp is running - #pragma unroll - for(int k = 0; k < num_values_4bit; k++) - { - local_B[k*2] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] >> 4)*local_absmax; - local_B[k*2 + 1] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] & 0x0F)*local_absmax; - } - } - else - { - // part of the warp exited already - #pragma unroll - for(int k = 0; k < num_values_4bit; k++) - { - local_B[k*2] = quant_map2[(local_B_4bit[k] >> 4)]*local_absmax; - local_B[k*2 + 1] = quant_map2[(local_B_4bit[k] & 0x0F)]*local_absmax; - } + local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; + local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; } if(inner_idx+num_values_4bit) @@ -3603,6 +3586,7 @@ template __global__ void kgemm_4bit_inference_naive(in reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 1]; reinterpret_cast(local_A)[2] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 2]; reinterpret_cast(local_A)[3] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 3]; + } else for(int k = 0; k < num_values_4bit; k++) @@ -3610,14 +3594,14 @@ template __global__ void kgemm_4bit_inference_naive(in #pragma unroll for(int k = 0; k < num_values_4bit; k++) - local_C += local_A[k]*local_B[k]; + local_C += (float)(local_A[k]*local_B[k]); } local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); if(row_B < M && warp_lane == 0) - out[row_B] = local_C; + out[row_B] = T(local_C); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 70d5515..552ccaa 100644 --- 
a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2420,7 +2420,7 @@ def test_cutlass3_gemm(dtype): def test_gemm_4bit(dtype): print('') #for dim in [64, 128, 256, 512, 1024, 2048, 4096]: - for dim in [4096]: + for dim in [4*1024]: errs = [] relerrs = [] max_err = 0 @@ -2485,10 +2485,10 @@ def test_gemm_4bit(dtype): #print(dim, sum(errs)/len(errs)/math.sqrt(dim)) #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) #print(dim, (max_err.item(), max_relerr.item())) - #print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) - #print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) - #assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011 - #assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15 + print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) + print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) + assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011 + assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15 @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): From 4b88d69de76f4e876d71665f48392b4c12e48867 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 9 Jul 2023 12:04:09 -0700 Subject: [PATCH 08/13] Added abitrary data types; fixed a bug for small matrices. --- bitsandbytes/autograd/_functions.py | 2 +- bitsandbytes/functional.py | 70 ++++++++++++++++----- csrc/kernels.cu | 13 ++-- csrc/kernels.cuh | 2 +- csrc/ops.cu | 8 +-- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 16 ++--- tests/test_functional.py | 94 ++++++----------------------- 8 files changed, 98 insertions(+), 109 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index eeef93b..7848b7e 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -562,6 +562,6 @@ def matmul( def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): assert quant_state is not None if A.numel() == A.shape[-1] and A.requires_grad == False: - return F.cutlass3_gemm(A, B.t(), out, state=quant_state) + return F.gemv_4bit(A, B.t(), out, state=quant_state) else: return MatMul4Bit.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 95a15d5..e09b267 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -240,17 +240,19 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() v2 = [0]*(256-15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() - v = v1 + v2 + v3 else: v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() v2 = [0]*(256-14) ## we have 14 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() - v = v1 + v2 + v3 + + v = v1 + v2 + v3 values = torch.Tensor(v) values = values.sort().values values /= values.max() + assert values.numel() == 256 + return values def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): @@ -710,6 +712,47 @@ def dequantize_blockwise( return out +def get_4bit_type(typename, device=None, blocksize=64): + if device is None: device = 'cuda' + data = None + if typename == 'nf4': + data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, + -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, + 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, + 0.7229568362236023, 1.0] + elif typename == 
'fp4': + # 0b000 = 0 + # 0b001 = 0.0625 + # 0b010 = 8 + # 0b011 = 12 + # 0b100 = 4 + # 0b101 = 6 + # 0b110 = 2 + # 0b111 = 3 + data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] + elif typename == 'int4': + data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] + elif typename == 'af4': + # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good) + # https://arxiv.org/abs/2306.06965 + if blocksize == 64: + data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, + -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, + 0.42563882, 0.55496234, 0.72424863, 1.][::-1] + else: + raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.') + + if data is None: + raise NotImplementedError(f'Typename {typename} not supported') + + data = Tensor(data) + data /= data.abs().max() + assert data.numel() == 16 + + return data.to(device) + + + def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') @@ -783,6 +826,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) + datatype = get_4bit_type(quant_type, device=A.device) + if compress_statistics: offset = absmax.mean() absmax -= offset @@ -790,9 +835,9 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype] else: - state = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype] return out, state @@ -839,7 +884,7 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: shape = out.shape dtype = out.dtype else: - absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state + absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state if compressed_stats is not None: @@ -1408,13 +1453,14 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 return sout -def cutlass3_gemm( +def gemv_4bit( A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False, - state=None + state=None, + storage_type='nf4' ): #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: @@ -1491,8 +1537,6 @@ def cutlass3_gemm( ldb = sA[2] ldc = m - ptr = CUBLAS_Context.get_instance().get_context(A.device) - # B^T @ A^T = C^T # [km, nk -> mn] #lda = ldb = ldc = 1 @@ -1514,15 +1558,11 @@ def cutlass3_gemm( if B.dtype == torch.uint8: if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, 
ct.c_int32(state[3])) + lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) else: raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') - elif A.dtype == torch.float32: - lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) - elif A.dtype == torch.float16: - lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) else: raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index dd1f6f2..4131477 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3520,7 +3520,7 @@ template __global__ void kgemm_4bit_inference(int M, i } #define num_values_4bit 32 -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { // per threadblock: @@ -3568,7 +3568,9 @@ template __global__ void kgemm_4bit_inference_naive(in { #pragma unroll for(int j = 0; j < (num_values_8bit); j++) - if((inner_idx/2) + j < K) + if((inner_idx_halved) + j < K) + local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; + else local_B_4bit[j] = 0b01110111; } } @@ -3578,6 +3580,9 @@ template __global__ void kgemm_4bit_inference_naive(in { local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; + + //if(threadIdx.x == 0) + //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax); } if(inner_idx+num_values_4bit) @@ -3773,8 +3778,8 @@ template __global__ void kgemm_4bit_inference(int M, int N, int K, ha template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, 
char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 05d0715..d5349d6 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -125,7 +125,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kfunc(T *A, T *B, T value, long n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 902129f..8bcee2c 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -729,12 +729,12 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { int num_blocks = (m+3)/4; - kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } template void func(T *A, T *B, T value, long n) @@ -757,8 +757,8 @@ template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index e4df195..699ff20 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int 
ldb, int rows template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); template void func(T *A, T *B, T value, long n); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 4cbabae..b1a079f 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -28,11 +28,11 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } -void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ @@ -394,11 +394,11 @@ extern "C" CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) - void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } - void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, 
blocksize); } #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 552ccaa..54af27d 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1816,7 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden): linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() - F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4) + F.gemv_4bit(A, B_nf4.t(), state=state_nf4) # warmup for i in range(iters): @@ -1849,7 +1849,7 @@ def test_bench_matmul(batch, seq, model, hidden): t0 = time.time() for i in range(iters): #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) - F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4) + F.gemv_4bit(A, B_nf4.t(), state=state_nf4) torch.cuda.synchronize() print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) @@ -2351,76 +2351,14 @@ def test_normal_map_tree(): print(pivots) -#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) -@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) -def test_cutlass3_gemm(dtype): - debug = True - #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: - #for dim in [4096, 5120, 6656, 8192]: - for dim in [4096]: - #for dim in [128+1]: - errs = [] - relerrs = [] - max_err = 0 - max_relerr = 0 - for i in range(100): - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) - #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - - #print('') - #print(A) - #print(B.t()) - #A[:, :-1] = 0 - #B[:, :-1] = 0 - - - C1 = torch.matmul(A, B.t()) - C2 = F.cutlass3_gemm(A, B.t()) - - # tensor cores are non-deterministic - # so we need to analyze errors around the mean - # to test our implementation - err = torch.abs(C1-C2) - mag = torch.abs(C1)+1e-8 - relerr = err/mag - max_err = max(err.max(), max_err) - max_relerr = max(relerr.max(), max_relerr) - err = err.mean().item() - relerr = relerr.mean().item() - - errs.append(err) - relerrs.append(relerr) - - #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: - # print('') - # print(i, err, relerr) - # print(A.flatten()[-6:]) - # print(B.flatten()[-6:]) - # out = A.flatten()[-6:]*B.flatten()[-6:] - # print(out) - # print(out[:-1].sum()) - # print('='*80) - # print(C1.flatten()[-6:]) - # print(C2.flatten()[-6:]) - # #assert False, 'ERROR' - - c = int(C1.numel()*0.0014*(dim/256))+1 - - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug) - #print(c/math.sqrt(dim)) - print('') - print(dim, sum(errs)/len(errs)/math.sqrt(dim)) - print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) - print(dim, (max_err.item(), max_relerr.item())) - #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16']) #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16']) -def test_gemm_4bit(dtype): +def test_gemv_4bit(dtype): print('') - #for dim in [64, 128, 256, 512, 1024, 2048, 4096]: - for dim in [4*1024]: + for dim in [64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4*1024]: + #for dim in [1*16]: errs = [] relerrs = [] max_err = 0 @@ -2446,9 +2384,10 @@ def 
test_gemm_4bit(dtype): qB, state = F.quantize_nf4(B) F.dequantize_nf4(qB, state) - #C2 = bnb.matmul_4bit(A, qB.t(), state) - C2 = F.cutlass3_gemm(A, qB.t(), state=state) - C1 = torch.matmul(A, B.t()) + C2 = F.gemv_4bit(A, qB.t(), state=state) + C3 = torch.matmul(A, B.t()) + A.requires_grad = True + C1 = bnb.matmul_4bit(A, qB.t(), state) #print(state) #print(qB) @@ -2457,8 +2396,7 @@ def test_gemm_4bit(dtype): #print(A) #print(B) #print('='*89) - #print(C1) - #print(C2) + #print(C3.flatten()[-20:]) #print(C3) #print(C1.shape, C2.shape) @@ -2485,10 +2423,16 @@ def test_gemm_4bit(dtype): #print(dim, sum(errs)/len(errs)/math.sqrt(dim)) #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) #print(dim, (max_err.item(), max_relerr.item())) + print(C1.flatten()[-20:]) + print(C2.flatten()[-20:]) print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) - assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011 - assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15 + if dtype == torch.float16: + assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5 + assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005 + else: + assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4 + assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003 @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): From 94168d79d74174ee4ba7c183e2cfc7dacc89c939 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 9 Jul 2023 14:46:19 -0700 Subject: [PATCH 09/13] Added FP4 fast inference support. --- bitsandbytes/autograd/_functions.py | 4 ++-- bitsandbytes/functional.py | 3 +-- csrc/kernels.cu | 6 ++---- tests/test_functional.py | 17 +++++++++-------- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7848b7e..22f89b1 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -509,7 +509,7 @@ class MatMul4Bit(torch.autograd.Function): # 1. Dequantize # 2. MatmulnN - output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias) + output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias) # 3. Save state ctx.state = state @@ -540,7 +540,7 @@ class MatMul4Bit(torch.autograd.Function): # not supported by PyTorch. 
TODO: create work-around #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) - if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t()) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) return grad_A, grad_B, None, grad_bias, None diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index e09b267..1f658ac 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1459,8 +1459,7 @@ def gemv_4bit( out: Tensor = None, transposed_A=False, transposed_B=False, - state=None, - storage_type='nf4' + state=None ): #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4131477..1aaeb22 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3546,7 +3546,8 @@ template __global__ void kgemm_4bit_inference_naive(in T local_absmax = T(0.0f); for(int i = threadIdx.x; i < 16; i++) - quant_map[i] = nf4_data[i]; + quant_map[i] = datatype[i]; + __syncthreads(); // A: [1, K] @@ -3580,9 +3581,6 @@ template __global__ void kgemm_4bit_inference_naive(in { local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; - - //if(threadIdx.x == 0) - //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax); } if(inner_idx+num_values_4bit) diff --git a/tests/test_functional.py b/tests/test_functional.py index 54af27d..68688ed 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2351,12 +2351,13 @@ def test_normal_map_tree(): print(pivots) +@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16']) #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16']) -def test_gemv_4bit(dtype): +def test_gemv_4bit(dtype, storage_type): print('') - for dim in [64, 128, 256, 512, 1024, 2048, 4096]: + for dim in [128, 256, 512, 1024, 2048, 4096]: #for dim in [4*1024]: #for dim in [1*16]: errs = [] @@ -2364,7 +2365,7 @@ def test_gemv_4bit(dtype): max_err = 0 max_relerr = 0 - for i in range(100): + for i in range(1): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') #A = torch.rand(1, 4096, dtype=dtype, device='cuda') @@ -2381,8 +2382,8 @@ def test_gemv_4bit(dtype): #A.flatten()[:-1] = 0 #B.flatten()[:-1] = 0 - qB, state = F.quantize_nf4(B) - F.dequantize_nf4(qB, state) + qB, state = F.quantize_4bit(B, quant_type=storage_type) + F.dequantize_4bit(qB, state) C2 = F.gemv_4bit(A, qB.t(), state=state) C3 = torch.matmul(A, B.t()) @@ -2396,7 +2397,6 @@ def test_gemv_4bit(dtype): #print(A) #print(B) #print('='*89) - #print(C3.flatten()[-20:]) #print(C3) #print(C1.shape, C2.shape) @@ -2425,8 +2425,9 @@ def test_gemv_4bit(dtype): #print(dim, (max_err.item(), max_relerr.item())) print(C1.flatten()[-20:]) print(C2.flatten()[-20:]) - print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) - print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) + print(C3.flatten()[-20:]) + print(sum(errs)/len(errs)/math.sqrt(dim) , dim) + print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim) if dtype == torch.float16: 
assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5 assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005 From 0f0390acb2a6307c6a92bbef2ff095bd7cbcdc90 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 9 Jul 2023 15:32:03 -0700 Subject: [PATCH 10/13] Added double quantization support and tests. --- bitsandbytes/functional.py | 25 ++++++++++++++++++------- tests/test_functional.py | 31 ++++++++++++++++++++----------- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 1f658ac..aa18925 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1461,16 +1461,25 @@ def gemv_4bit( transposed_B=False, state=None ): + prev_device = pre_call(A.device) #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - Bshape = B.shape - bout = Bshape[1] - else: - Bshape = state[1] - bout = Bshape[0] + raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') + + Bshape = state[1] + bout = Bshape[0] + absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state + if compressed_stats is not None: + offset, state2 = compressed_stats + absmax = dequantize_blockwise(absmax, state2) + absmax += offset + if out is None: out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) + + + sA = A.shape sB = B.shape if transposed_A and len(sA) == 2: @@ -1557,14 +1566,16 @@ def gemv_4bit( if B.dtype == torch.uint8: if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) else: raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') else: raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + post_call(prev_device) + return out diff --git a/tests/test_functional.py b/tests/test_functional.py index 68688ed..6dff784 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 1 +batch_size = 5 seqdim = 1 values = [] #values.append((batch_size, seqdim, 768, 4 * 768)) @@ -1786,8 +1786,8 @@ values = [] #values.append((batch_size, seqdim, 2560, 4*2560)) #values.append((batch_size, seqdim, 4096, 4*4096)) #values.append((batch_size, seqdim, 5120, 4*5120)) -#values.append((batch_size, seqdim, 6656, 4*6656)) -values.append((batch_size, seqdim, 8192, 4*8192)) +values.append((batch_size, seqdim, 6656, 4*6656)) +#values.append((batch_size, seqdim, 8192, 4*8192)) #values.append((batch_size, seqdim, 5140, 4*5140)) #values.append((batch_size, seqdim, 12288, 4*12288)) names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @@ -1804,6 +1804,7 @@ def test_bench_matmul(batch, seq, model, hidden): B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) B_nf4, state_nf4 = F.quantize_nf4(B) + B_nf4_c, 
state_nf4_c = F.quantize_nf4(B, compress_statistics=True) linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half() linear8bit.eval() @@ -1816,7 +1817,7 @@ def test_bench_matmul(batch, seq, model, hidden): linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() - F.gemv_4bit(A, B_nf4.t(), state=state_nf4) + bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) # warmup for i in range(iters): @@ -1848,11 +1849,18 @@ def test_bench_matmul(batch, seq, model, hidden): torch.cuda.synchronize() t0 = time.time() for i in range(iters): - #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) - F.gemv_4bit(A, B_nf4.t(), state=state_nf4) + bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) torch.cuda.synchronize() print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c) + torch.cuda.synchronize() + print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + #torch.cuda.synchronize() #t0 = time.time() #for i in range(iters): @@ -2351,11 +2359,12 @@ def test_normal_map_tree(): print(pivots) +@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False']) @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16']) #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16']) -def test_gemv_4bit(dtype, storage_type): +def test_gemv_4bit(dtype, storage_type, double_quant): print('') for dim in [128, 256, 512, 1024, 2048, 4096]: #for dim in [4*1024]: @@ -2365,7 +2374,7 @@ def test_gemv_4bit(dtype, storage_type): max_err = 0 max_relerr = 0 - for i in range(1): + for i in range(100): #A = torch.rand(2, 4092, dtype=dtype, device='cuda') #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') #A = torch.rand(1, 4096, dtype=dtype, device='cuda') @@ -2382,11 +2391,11 @@ def test_gemv_4bit(dtype, storage_type): #A.flatten()[:-1] = 0 #B.flatten()[:-1] = 0 - qB, state = F.quantize_4bit(B, quant_type=storage_type) - F.dequantize_4bit(qB, state) + qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) + #F.dequantize_4bit(qB, state) - C2 = F.gemv_4bit(A, qB.t(), state=state) C3 = torch.matmul(A, B.t()) + C2 = F.gemv_4bit(A, qB.t(), state=state) A.requires_grad = True C1 = bnb.matmul_4bit(A, qB.t(), state) From 6a905be5ced93c46e35b675fbdc73d40bb95d3ee Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 9 Jul 2023 15:34:02 -0700 Subject: [PATCH 11/13] Fixed a bug where gemv_4bit would return a wrongly sized tensor. 
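A minimal sketch of the output-shape handling this fix introduces (the helper name alloc_gemv_4bit_out is illustrative and not part of the patch; bout follows the hunk below, where it is the quantized weight's leading dimension):

import torch

def alloc_gemv_4bit_out(A: torch.Tensor, bout: int) -> torch.Tensor:
    # Before this fix, gemv_4bit always allocated a 2D [A.shape[0], bout] buffer,
    # so a 3D activation of shape [batch, seq, hidden] came back wrongly sized.
    if len(A.shape) == 3:
        return torch.zeros((A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
    return torch.zeros((A.shape[0], bout), dtype=A.dtype, device=A.device)
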
--- bitsandbytes/functional.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index aa18925..78b5f4b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1475,7 +1475,10 @@ def gemv_4bit( absmax += offset if out is None: - out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) + if len(A.shape) == 3: + out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) + else: + out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) From cef519c89ed04fdd6f3c09a672f8520532a89994 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 9 Jul 2023 17:16:50 -0700 Subject: [PATCH 12/13] Added test for Param4bit.to() and fixed double quant behavior. --- bitsandbytes/functional.py | 2 -- bitsandbytes/nn/modules.py | 6 +++--- tests/test_modules.py | 1 + 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 78b5f4b..c5514ed 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -831,8 +831,6 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz if compress_statistics: offset = absmax.mean() absmax -= offset - #code = create_custom_map().to(absmax.device) - #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype] diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 3284921..2407afb 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -188,9 +188,9 @@ class Params4bit(torch.nn.Parameter): #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax # for 8-bit - s[-2][0] = s[-2][0].to(device) # offset - s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics - s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook + s[-3][0] = s[-3][0].to(device) # offset + s[-3][1][0] = s[-3][1][0].to(device) # nested quantiation state statitics + s[-3][1][1] = s[-3][1][1].to(device) # nested quantiation codebook new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, quant_state=self.quant_state, blocksize=self.blocksize, compress_statistics=self.compress_statistics, diff --git a/tests/test_modules.py b/tests/test_modules.py index d0a9051..a187484 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -535,6 +535,7 @@ def test_kbit_backprop(module): kbit[1].bias.detach().copy_(ref[1].bias) ref = ref.half().cuda() kbit = kbit.half().cuda() + kbit = kbit.half().to('cuda') errs1 = [] errs2 = [] From 5fab6734424a78a2a4594525386cd84feb67fb50 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 9 Jul 2023 21:06:01 -0700 Subject: [PATCH 13/13] Added fp32 compute type for gemv_4bit. --- bitsandbytes/functional.py | 96 ++++++-------------------------------- csrc/kernels.cu | 31 ++++++++---- csrc/kernels.cuh | 2 +- csrc/ops.cu | 10 ++-- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 10 +++- tests/test_functional.py | 8 ++-- 7 files changed, 55 insertions(+), 104 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c5514ed..1972462 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1464,6 +1464,9 @@ def gemv_4bit( if state is None: raise ValueError(f'state cannot None. 
gem_4bit( ) requires the state from quantize_4bit( )') + if A.numel() != A.shape[-1]: + raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') + Bshape = state[1] bout = Bshape[0] absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state @@ -1474,90 +1477,17 @@ def gemv_4bit( if out is None: if len(A.shape) == 3: - out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) + out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) else: - out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) + out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) - - - - sA = A.shape - sB = B.shape - if transposed_A and len(sA) == 2: - sA = (sA[1], sA[0]) - elif transposed_A and len(sA) == 3: - sA = (sA[0], sA[2], sA[0]) - if transposed_B and len(sB) == 2: - sB = (sB[1], sB[0]) - elif transposed_B and len(sB) == 3: - sB = (sB[0], sB[2], sB[0]) - # this is a mess: cuBLAS expect column major, but PyTorch is row major. - # So to perform the matrix multiplication, we have to treat A, B, and C matrices - # (transpose of row major is column major) - # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these - - # matrices in the input arguments for cuBLAS - # column major: A @ B = C: [m, k] @ [k, n] = [m, n] - # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] - # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] - if len(sB) == 2: - if B.stride()[0] == B.shape[1]: - transposed_B = False - elif B.stride()[1] == B.shape[0]: - transposed_B = True - if len(A.shape) == 2: - if A.stride()[0] == A.shape[1]: - transposed_A = False - elif A.stride()[1] == A.shape[0]: - transposed_A = True - else: - if A.stride()[1] == A.shape[2]: - transposed_A = False - elif A.stride()[2] == A.shape[1]: - transposed_A = True - - if len(sA) == 2: - n = sA[0] - ldb = A.stride()[1 if transposed_A else 0] - elif len(sA) == 3 and len(sB) == 2: - n = sA[0] * sA[1] - ldb = sA[2] - - m = sB[1] - k = sB[0] - lda = B.stride()[0] - ldc = sB[1] - elif len(sB) == 3: - # special case - assert len(sA) == 3 - if not (sA[0] == sB[0] and sA[1] == sB[1]): - raise ValueError( - f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" - ) - - transposed_A = True - transposed_B = False - - m = sB[2] - n = sA[2] - k = sB[0] * sB[1] - - lda = n - ldb = sA[2] - ldc = m - - # B^T @ A^T = C^T - # [km, nk -> mn] - #lda = ldb = ldc = 1 - #lda = 1 - if state is not None: - m = Bshape[0] - k = Bshape[1] - lda = Bshape[0] - ldc = Bshape[0] - ldb = (ldb+1)//2 - #print(m, n, k, lda, ldb, ldc) - is_on_gpu([B, A, out]) + n = 1 + m = Bshape[0] + k = Bshape[1] + lda = Bshape[0] + ldc = Bshape[0] + ldb = (A.shape[-1]+1)//2 + is_on_gpu([B, A, out, absmax, state[-1]]) m = ct.c_int32(m) n = ct.c_int32(n) k = ct.c_int32(k) @@ -1570,6 +1500,8 @@ def gemv_4bit( lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) elif A.dtype == torch.bfloat16: lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) 
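    # Clarifying comment (added for readability, not in the original patch): this naive path
    # is a batch-size-1 GEMV, so n is hard-coded to 1, m and k come from the quantized weight
    # shape Bshape, and ldb = (A.shape[-1] + 1) // 2 is the packed row width in bytes, since
    # two 4-bit values share one byte. The fp16/bf16/fp32 branches only select the matching
    # cgemm_4bit_inference_naive_* entry point for A's dtype.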
else: raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') else: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1aaeb22..4b05672 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3520,7 +3520,7 @@ template __global__ void kgemm_4bit_inference(int M, i } #define num_values_4bit 32 -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { // per threadblock: @@ -3528,7 +3528,6 @@ template __global__ void kgemm_4bit_inference_naive(in // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1 // 4 warps -> 4 loads per iter // 1x128 * 128x4 -> 1x4 outputs - //typedef cub::WarpReduce WarpReduce; typedef cub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; @@ -3536,7 +3535,6 @@ template __global__ void kgemm_4bit_inference_naive(in const int warp_lane = threadIdx.x % 32; const int row_B = (THREADS/32)*blockIdx.x + warp_idx; const int num_values_8bit = num_values_4bit/2; - //T local_C = T(0.0f); float local_C = 0.0f; unsigned char local_B_4bit[num_values_8bit]; @@ -3585,10 +3583,24 @@ template __global__ void kgemm_4bit_inference_naive(in if(inner_idx+num_values_4bit) { - reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 0]; - reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 1]; - reinterpret_cast(local_A)[2] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 2]; - reinterpret_cast(local_A)[3] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 3]; + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + 1]; + reinterpret_cast(local_A)[2] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + 2]; + reinterpret_cast(local_A)[3] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + 3]; + } + else + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + 1]; + reinterpret_cast(local_A)[2] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + 2]; + reinterpret_cast(local_A)[3] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + 3]; + reinterpret_cast(local_A)[4] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + 4]; + reinterpret_cast(local_A)[5] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + 5]; + reinterpret_cast(local_A)[6] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + 6]; + reinterpret_cast(local_A)[7] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + 7]; + } } else @@ -3776,8 +3788,9 @@ template __global__ void kgemm_4bit_inference(int M, int N, int K, ha template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half 
* __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index d5349d6..a7fe3d7 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -125,7 +125,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kfunc(T *A, T *B, T value, long n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 8bcee2c..b524e0e 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -729,12 +729,12 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { int num_blocks = (m+3)/4; - kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } template void func(T *A, T *B, T value, long n) @@ -757,8 +757,10 @@ template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float 
*absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 699ff20..f37b3b3 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); template void func(T *A, T *B, T value, long n); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index b1a079f..0aa82fe 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -29,10 +29,13 @@ void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, floa { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } +{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + +void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, 
lda, ldb, ldc, blocksize); } #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ @@ -400,6 +403,9 @@ extern "C" void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 6dff784..34552cb 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 5 +batch_size = 1 seqdim = 1 values = [] #values.append((batch_size, seqdim, 768, 4 * 768)) @@ -1793,7 +1793,7 @@ values.append((batch_size, seqdim, 6656, 4*6656)) names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): - iters = 80 + iters = 1000 formatB = F.get_special_format_str() A = torch.randn(batch, seq, model, device="cuda").half() @@ -2361,9 +2361,7 @@ def test_normal_map_tree(): @pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False']) @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) -#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16']) -#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16']) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) def test_gemv_4bit(dtype, storage_type, double_quant): print('') for dim in [128, 256, 512, 1024, 2048, 4096]:
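
For context, a minimal usage sketch of the path exercised by test_gemv_4bit, assuming a CUDA device and the fp32 compute type added in this patch; the tensor sizes are illustrative, and gemv_4bit expects A to be a single row per the shape check added earlier in the series:

import torch
import bitsandbytes.functional as F

A = torch.randn(1, 4096, dtype=torch.float32, device='cuda')          # single-row activation
B = torch.randn(4 * 4096, 4096, dtype=torch.float32, device='cuda')   # full-precision weight
qB, state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=True)  # 4-bit weight + double-quantized absmax
C = F.gemv_4bit(A, qB.t(), state=state)                               # dispatches to the naive fp32 GEMV kernel
print(C.shape)  # torch.Size([1, 16384])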