diff --git a/Makefile b/Makefile
index 5fa1f17..2cbb1b9 100644
--- a/Makefile
+++ b/Makefile
@@ -47,8 +47,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
 
 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
-CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
-CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
 
 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index ab12c37..7a752cb 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3297,6 +3297,7 @@ template <typename T, int THREADS> __global__ void gemm_device(int M,
 #endif
 }
 
+__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
 
@@ -3308,6 +3309,12 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   const int half_warp_lane = threadIdx.x % 16;
   const int batch_size_warps = (WARPS-1)*2;
 
+  T quant_map[16];
+
+  #pragma unroll 16
+  for(int i = 0; i < 16; i++)
+    quant_map[i] = nf4_data[i];
+
   T local_A[2];
   T local_B[64];
   unsigned char local_B_4bit[32];
@@ -3410,6 +3417,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
         {
           local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
           local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+          //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
+          //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
         }
       }
 
diff --git a/tests/test_functional.py b/tests/test_functional.py
index cc58324..29b82e6 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2297,7 +2297,8 @@ def test_4bit_compressed_stats(quant_type):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+@pytest.mark.parametrize("quant_type", ['nf4'])
 def test_bench_4bit_dequant(quant_type):
     blocksize = 256
     a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
@@ -2311,7 +2312,7 @@ def test_bench_4bit_dequant(quant_type):
     #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()
 
-    iters = 5
+    iters = 100
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
@@ -2438,9 +2439,11 @@ def test_gemm_4bit(dtype):
 
         C3 = torch.matmul(A, B.t())
         C2 = F.cutlass3_gemm(A, qB.t(), state=state)
         C1 = bnb.matmul_4bit(A, qB.t(), state)
-        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-        print(C1.shape, C2.shape)
+        print(C1)
+        print(C2)
+
+        #print(C1.shape, C2.shape)
 
         # tensor cores are non-deterministic
         # so we need to analyze errors around the mean
@@ -2452,6 +2455,7 @@ def test_gemm_4bit(dtype):
         max_relerr = max(relerr.max(), max_relerr)
         err = err.mean().item()
         relerr = relerr.mean().item()
+        print(err)
 
         errs.append(err)
         relerrs.append(relerr)
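
The diff above introduces an NF4 codebook (nf4_data) and a commented-out quant_map lookup path as an alternative to dhDequantizeNF4. As a sanity check, the same table-based dequantization can be expressed in a few lines of PyTorch. The sketch below is illustrative only: NF4_CODE, dequantize_nf4, and the blocksize default are assumed names, not bitsandbytes API; it mirrors the kernel's nibble order (high nibble = even element, low nibble = odd element) and per-block absmax scaling.

import torch

# Illustrative sketch (assumed names): same 16-entry NF4 codebook
# as nf4_data in the CUDA kernel above.
NF4_CODE = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0])

def dequantize_nf4(packed, absmax, blocksize=64):
    # Two 4-bit codes per byte: the high nibble is the even element,
    # the low nibble the odd one -- same order as the kernel's
    # local_B[col] / local_B[col+1] pair.
    hi = torch.bitwise_right_shift(packed, 4).long()
    lo = torch.bitwise_and(packed, 0x0F).long()
    # Interleave the looked-up values: [hi0, lo0, hi1, lo1, ...].
    vals = torch.stack((NF4_CODE[hi], NF4_CODE[lo]), dim=-1).reshape(-1)
    # Rescale each block of `blocksize` values by its absmax,
    # matching the kernel's *T(absidx) factor.
    return (vals.view(-1, blocksize) * absmax.view(-1, 1)).reshape(-1)

Comparing this function's output against F.dequantize_4bit with quant_type='nf4' on random data is a quick way to validate the quant_map lookup path before un-commenting it in the kernel.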