Added FP4 fast inference support.

Tim Dettmers 2023-07-09 14:46:19 -07:00
parent 4b88d69de7
commit 94168d79d7
4 changed files with 14 additions and 16 deletions

View File

@@ -509,7 +509,7 @@ class MatMul4Bit(torch.autograd.Function):
 # 1. Dequantize
 # 2. MatmulnN
-output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
+output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
 # 3. Save state
 ctx.state = state
@@ -540,7 +540,7 @@ class MatMul4Bit(torch.autograd.Function):
 # not supported by PyTorch. TODO: create work-around
 #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())
+if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 return grad_A, grad_B, None, grad_bias, None
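
Note (not part of the diff): both the forward and the backward pass now go through F.dequantize_4bit, so the quantization state decides whether the packed weight is NF4 or FP4 instead of a hard-coded FP4 call. Below is a minimal sketch of the equivalent dequantize-then-matmul step, with illustrative shapes and assuming bitsandbytes.functional is imported as F:

# Sketch only: emulate the dequantize + linear step from MatMul4Bit.forward.
import torch
import bitsandbytes.functional as F

A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')     # activations, shape [1, K]
W = torch.randn(1024, 4096, dtype=torch.float16, device='cuda')  # weight, shape [N, K]

qW, state = F.quantize_4bit(W, quant_type='fp4')   # 'nf4' takes the same path; state records the type
W_deq = F.dequantize_4bit(qW, state).to(A.dtype)   # restore an [N, K] fp16 copy of the weight
out = torch.nn.functional.linear(A, W_deq)         # A @ W_deq.t(), approximately equal to A @ W.t()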

View File

@@ -1459,8 +1459,7 @@ def gemv_4bit(
 out: Tensor = None,
 transposed_A=False,
 transposed_B=False,
-state=None,
-storage_type='nf4'
+state=None
 ):
 #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
 if state is None:
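
Note (not part of the diff): the explicit storage_type argument is gone; gemv_4bit now infers the 4-bit format from the quantization state returned by quantize_4bit. A hedged usage sketch of the new call signature, mirroring the updated test further down (shapes are illustrative):

# Sketch only: single-row (GEMV) inference with a 4-bit quantized weight.
import torch
import bitsandbytes.functional as F

x = torch.randn(1, 4096, dtype=torch.float16, device='cuda')     # one token's activations, [1, K]
W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')  # weight, [N, K]

qW, state = F.quantize_4bit(W, quant_type='nf4')   # or 'fp4'; the state carries the choice
y = F.gemv_4bit(x, qW.t(), state=state)            # no storage_type argument anymore
y_ref = torch.matmul(x, W.t())                     # dense fp16 reference for comparison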

View File

@@ -3546,7 +3546,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
 T local_absmax = T(0.0f);
 for(int i = threadIdx.x; i < 16; i++)
-  quant_map[i] = nf4_data[i];
+  quant_map[i] = datatype[i];
 __syncthreads();
 // A: [1, K]
@@ -3580,9 +3581,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
 {
 local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
 local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
-//if(threadIdx.x == 0)
-//printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax);
 }
 if(inner_idx+num_values_4bit)
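
Note (not part of the diff): the kernel used to fill its shared quant_map from a hard-coded NF4 table (nf4_data); it now copies the 16-entry codebook from the datatype pointer it receives, so the same kernel serves both NF4 and FP4. A small Python sketch of what the inner loop computes for one packed byte; the codebook values below are placeholders, not the real NF4/FP4 tables:

# Sketch only: dequantize one packed byte the way the kernel's inner loop does.
import torch

codebook = torch.linspace(-1.0, 1.0, 16)          # stand-in for the 16-entry table passed via `datatype`
packed = torch.tensor([0x4F], dtype=torch.uint8)  # one byte holds two 4-bit weights
absmax = 0.5                                      # per-block absmax scale (illustrative)

hi = codebook[(packed >> 4).long()] * absmax      # quant_map[b >> 4]   * local_absmax
lo = codebook[(packed & 0x0F).long()] * absmax    # quant_map[b & 0x0F] * local_absmax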

View File

@@ -2351,12 +2351,13 @@ def test_normal_map_tree():
 print(pivots)
+@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemv_4bit(dtype):
+def test_gemv_4bit(dtype, storage_type):
 print('')
-for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
+for dim in [128, 256, 512, 1024, 2048, 4096]:
 #for dim in [4*1024]:
 #for dim in [1*16]:
 errs = []
@@ -2364,7 +2365,7 @@ def test_gemv_4bit(dtype):
 max_err = 0
 max_relerr = 0
-for i in range(100):
+for i in range(1):
 #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
 #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
 #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
@@ -2381,8 +2382,8 @@ def test_gemv_4bit(dtype):
 #A.flatten()[:-1] = 0
 #B.flatten()[:-1] = 0
-qB, state = F.quantize_nf4(B)
-F.dequantize_nf4(qB, state)
+qB, state = F.quantize_4bit(B, quant_type=storage_type)
+F.dequantize_4bit(qB, state)
 C2 = F.gemv_4bit(A, qB.t(), state=state)
 C3 = torch.matmul(A, B.t())
@@ -2396,7 +2397,6 @@ def test_gemv_4bit(dtype):
 #print(A)
 #print(B)
 #print('='*89)
-#print(C3.flatten()[-20:])
 #print(C3)
 #print(C1.shape, C2.shape)
@@ -2425,8 +2425,9 @@ def test_gemv_4bit(dtype):
 #print(dim, (max_err.item(), max_relerr.item()))
 print(C1.flatten()[-20:])
 print(C2.flatten()[-20:])
-print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
-print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
+print(C3.flatten()[-20:])
+print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
+print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
 if dtype == torch.float16:
 assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
 assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
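
Note (not part of the diff): the test now runs over both storage types ('nf4' and 'fp4') and compares the 4-bit GEMV against an fp16 matmul, with the mean error scaled by sqrt(dim). A standalone sketch of that check for one configuration; the exact error bookkeeping (errs/relerrs) is not visible in this diff, so a plain mean absolute error stands in for it here:

# Sketch only: compare the fast 4-bit GEMV against a dense fp16 matmul, as the test does.
import math
import torch
import bitsandbytes.functional as F

dim = 4096
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
B = torch.randn(dim, dim, dtype=torch.float16, device='cuda')

qB, state = F.quantize_4bit(B, quant_type='fp4')   # parametrized as storage_type in the test
C2 = F.gemv_4bit(A, qB.t(), state=state)           # fast 4-bit inference path
C3 = torch.matmul(A, B.t())                        # fp16 reference

err = (C2 - C3).abs().float().mean().item()
relerr = ((C2 - C3) / (C3.abs() + 1e-8)).abs().float().mean().item()
print(err / math.sqrt(dim), relerr / math.sqrt(dim))  # fp16 thresholds in the test: 5e-5 and 5e-4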