Added lookup table.

This commit is contained in:
Tim Dettmers 2023-05-30 20:07:05 -07:00
parent ac5550a023
commit b7f04e2a20
3 changed files with 19 additions and 6 deletions

View File

@ -47,8 +47,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90

View File

@ -3297,6 +3297,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
#endif
}
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
@ -3308,6 +3309,12 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;
T quant_map[16];
#pragma unroll 16
for(int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i];
T local_A[2];
T local_B[64];
unsigned char local_B_4bit[32];
@ -3410,6 +3417,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
{
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
}
}

View File

@ -2297,7 +2297,8 @@ def test_4bit_compressed_stats(quant_type):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
def test_bench_4bit_dequant(quant_type):
blocksize = 256
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
@ -2311,7 +2312,7 @@ def test_bench_4bit_dequant(quant_type):
#print(max_theoretical_s*1e6)
b = torch.randn(128, 1024*12, device='cuda').half()
iters = 5
iters = 100
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
@ -2438,9 +2439,11 @@ def test_gemm_4bit(dtype):
C3 = torch.matmul(A, B.t())
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
C1 = bnb.matmul_4bit(A, qB.t(), state)
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
print(C1.shape, C2.shape)
print(C1)
print(C2)
#print(C1.shape, C2.shape)
# tensor cores are non-deterministic
# so we need to analyze errors around the mean
@ -2452,6 +2455,7 @@ def test_gemm_4bit(dtype):
max_relerr = max(relerr.max(), max_relerr)
err = err.mean().item()
relerr = relerr.mean().item()
print(err)
errs.append(err)
relerrs.append(relerr)