Added lookup table.
This commit is contained in:
parent
ac5550a023
commit
b7f04e2a20
4
Makefile
4
Makefile
|
@ -47,8 +47,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
|
||||||
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
|
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
|
||||||
|
|
||||||
CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
|
CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
|
||||||
CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
|
#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
|
||||||
CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
|
#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
|
||||||
|
|
||||||
CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
|
CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
|
||||||
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
|
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
|
||||||
|
|
|
@ -3297,6 +3297,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 16-entry float lookup table held in device memory.
// Entries are in ascending order, running from -1.0 (index 0) through 0.0
// (index 7) up to 1.0 (index 15). kgemm_4bit_inference copies the table
// element by element into its local quant_map[16] array, converting to T.
// NOTE(review): presumably these are the NF4 code values addressed by a
// 4-bit nibble -- confirm against the quantization side.
__device__ static float nf4_data[16] = {
    -1.0,
    -0.6961928009986877,
    -0.5250730514526367,
    -0.39491748809814453,
    -0.28444138169288635,
    -0.18477343022823334,
    -0.09105003625154495,
    0.0,
    0.07958029955625534,
    0.16093020141124725,
    0.24611230194568634,
    0.33791524171829224,
    0.44070982933044434,
    0.5626170039176941,
    0.7229568362236023,
    1.0
};
|
||||||
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
|
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
|
||||||
{
|
{
|
||||||
|
|
||||||
|
@ -3308,6 +3309,12 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
|
||||||
const int half_warp_lane = threadIdx.x % 16;
|
const int half_warp_lane = threadIdx.x % 16;
|
||||||
const int batch_size_warps = (WARPS-1)*2;
|
const int batch_size_warps = (WARPS-1)*2;
|
||||||
|
|
||||||
|
T quant_map[16];
|
||||||
|
|
||||||
|
#pragma unroll 16
|
||||||
|
for(int i = 0; i < 16; i++)
|
||||||
|
quant_map[i] = nf4_data[i];
|
||||||
|
|
||||||
T local_A[2];
|
T local_A[2];
|
||||||
T local_B[64];
|
T local_B[64];
|
||||||
unsigned char local_B_4bit[32];
|
unsigned char local_B_4bit[32];
|
||||||
|
@ -3410,6 +3417,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
|
||||||
{
|
{
|
||||||
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
|
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
|
||||||
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
|
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
|
||||||
|
//local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
|
||||||
|
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2297,7 +2297,8 @@ def test_4bit_compressed_stats(quant_type):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||||
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
|
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
|
||||||
|
@pytest.mark.parametrize("quant_type", ['nf4'])
|
||||||
def test_bench_4bit_dequant(quant_type):
|
def test_bench_4bit_dequant(quant_type):
|
||||||
blocksize = 256
|
blocksize = 256
|
||||||
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
|
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
|
||||||
|
@ -2311,7 +2312,7 @@ def test_bench_4bit_dequant(quant_type):
|
||||||
#print(max_theoretical_s*1e6)
|
#print(max_theoretical_s*1e6)
|
||||||
b = torch.randn(128, 1024*12, device='cuda').half()
|
b = torch.randn(128, 1024*12, device='cuda').half()
|
||||||
|
|
||||||
iters = 5
|
iters = 100
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
|
@ -2438,9 +2439,11 @@ def test_gemm_4bit(dtype):
|
||||||
C3 = torch.matmul(A, B.t())
|
C3 = torch.matmul(A, B.t())
|
||||||
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
|
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
|
||||||
C1 = bnb.matmul_4bit(A, qB.t(), state)
|
C1 = bnb.matmul_4bit(A, qB.t(), state)
|
||||||
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
|
|
||||||
|
|
||||||
print(C1.shape, C2.shape)
|
print(C1)
|
||||||
|
print(C2)
|
||||||
|
|
||||||
|
#print(C1.shape, C2.shape)
|
||||||
|
|
||||||
# tensor cores are non-deterministic
|
# tensor cores are non-deterministic
|
||||||
# so we need to analyze errors around the mean
|
# so we need to analyze errors around the mean
|
||||||
|
@ -2452,6 +2455,7 @@ def test_gemm_4bit(dtype):
|
||||||
max_relerr = max(relerr.max(), max_relerr)
|
max_relerr = max(relerr.max(), max_relerr)
|
||||||
err = err.mean().item()
|
err = err.mean().item()
|
||||||
relerr = relerr.mean().item()
|
relerr = relerr.mean().item()
|
||||||
|
print(err)
|
||||||
|
|
||||||
errs.append(err)
|
errs.append(err)
|
||||||
relerrs.append(relerr)
|
relerrs.append(relerr)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user