Added FP4 fast inference support.

parent 4b88d69de7
commit 94168d79d7
@@ -509,7 +509,7 @@ class MatMul4Bit(torch.autograd.Function):
 
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
 
         # 3. Save state
         ctx.state = state
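For context: `dequantize_4bit` reads the storage format out of the quantization state, so this single call path now serves both FP4 and NF4 weights. A minimal sketch of that round trip, assuming a CUDA build of bitsandbytes with `quantize_4bit`/`dequantize_4bit` exposed on `bitsandbytes.functional`:

import torch
import bitsandbytes.functional as F

A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')

for qt in ('fp4', 'nf4'):
    # quant_type is recorded inside `state`, so the matmul below does not
    # need to know which 4-bit format is in use
    qW, state = F.quantize_4bit(W, quant_type=qt)
    out = torch.nn.functional.linear(A, F.dequantize_4bit(qW, state).to(A.dtype))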
@@ -540,7 +540,7 @@ class MatMul4Bit(torch.autograd.Function):
 
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
 
         return grad_A, grad_B, None, grad_bias, None
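The backward rule being patched is the standard one for `out = linear(A, W)`: only `grad_A = grad_output @ W` is produced, with `W` recovered by dequantizing `B` on the fly (`grad_B` stays disabled per the TODO above). A quick sanity check of that rule in plain PyTorch, independent of bitsandbytes:

import torch

A = torch.randn(2, 8, requires_grad=True)
W = torch.randn(4, 8)

out = torch.nn.functional.linear(A, W)   # out = A @ W.t()
out.sum().backward()                     # grad_output is all ones

# grad_A = grad_output @ W, exactly the matmul the hunk rewrites
assert torch.allclose(A.grad, torch.ones(2, 4) @ W)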
@@ -1459,8 +1459,7 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None,
-    storage_type='nf4'
+    state=None
 ):
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
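With the explicit `storage_type` argument dropped, the FP4/NF4 choice now travels inside `state` instead of being passed twice. A hedged usage sketch mirroring the test further down, assuming a single-row `A` and a square weight `B` on CUDA:

import torch
import bitsandbytes.functional as F

A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')

qB, state = F.quantize_4bit(B, quant_type='fp4')  # state carries the quant type
C = F.gemv_4bit(A, qB.t(), state=state)           # no storage_type argument anymore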
@@ -3546,7 +3546,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
   T local_absmax = T(0.0f);
 
   for(int i = threadIdx.x; i < 16; i++)
-    quant_map[i] = nf4_data[i];
+    quant_map[i] = datatype[i];
 
   __syncthreads();
 
   // A: [1, K]
@@ -3580,9 +3581,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
       {
         local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
         local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
-
-        //if(threadIdx.x == 0)
-        //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax);
       }
 
       if(inner_idx+num_values_4bit)
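Those two `quant_map` lines are the entire dequantization step: each byte of `local_B_4bit` packs two 4-bit indices into the 16-entry `datatype` table (hard-coded to NF4 before this commit), each scaled by the block's `absmax`. A hedged Python model of that inner loop:

import torch

def dequant_pair(packed_byte: int, datatype: torch.Tensor, absmax: float):
    # upper nibble first, then lower nibble, matching local_B[k*2] / local_B[k*2 + 1]
    hi = datatype[(packed_byte >> 4) & 0x0F] * absmax
    lo = datatype[packed_byte & 0x0F] * absmax
    return hi, lo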
@@ -2351,12 +2351,13 @@ def test_normal_map_tree():
     print(pivots)
 
 
+@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemv_4bit(dtype):
+def test_gemv_4bit(dtype, storage_type):
     print('')
-    for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
+    for dim in [128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4*1024]:
     #for dim in [1*16]:
         errs = []
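Stacked with the existing `dtype` decorator, the new `storage_type` parametrization makes pytest generate one test per (dtype, storage) pair, with IDs along the lines of `test_gemv_4bit[fp16-fp4]` (exact ID order depends on decorator stacking). A hedged way to run just the new FP4 variants, assuming the file lives at tests/test_functional.py:

pytest tests/test_functional.py -k "gemv_4bit and fp4"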
@@ -2364,7 +2365,7 @@ def test_gemv_4bit(dtype):
         max_err = 0
         max_relerr = 0
 
-        for i in range(100):
+        for i in range(1):
            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
@@ -2381,8 +2382,8 @@ def test_gemv_4bit(dtype):
            #A.flatten()[:-1] = 0
            #B.flatten()[:-1] = 0
 
-           qB, state = F.quantize_nf4(B)
-           F.dequantize_nf4(qB, state)
+           qB, state = F.quantize_4bit(B, quant_type=storage_type)
+           F.dequantize_4bit(qB, state)
 
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            C3 = torch.matmul(A, B.t())
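Between `C2`/`C3` and the prints further down, the test accumulates per-iteration errors into `errs`/`relerrs` (those lines are not shown in this diff). A hedged sketch of what such bookkeeping typically looks like, not necessarily the author's exact metric:

err = (C2 - C3).abs().float().mean().item()      # mean absolute error
relerr = err / C3.abs().float().mean().item()    # scale-free counterpart
errs.append(err)
relerrs.append(relerr)
max_err = max(max_err, err)
max_relerr = max(max_relerr, relerr)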
@@ -2396,7 +2397,6 @@ def test_gemv_4bit(dtype):
            #print(A)
            #print(B)
            #print('='*89)
-           #print(C3.flatten()[-20:])
            #print(C3)
 
            #print(C1.shape, C2.shape)
@@ -2425,8 +2425,9 @@ def test_gemv_4bit(dtype):
            #print(dim, (max_err.item(), max_relerr.item()))
            print(C1.flatten()[-20:])
            print(C2.flatten()[-20:])
-           print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
-           print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
+           print(C3.flatten()[-20:])
+           print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
+           print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
            if dtype == torch.float16:
               assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
               assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
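The `math.sqrt(dim)` divisor in the prints and asserts is worth a note: for roughly i.i.d. inputs, each output element of `A @ B.t()` sums `dim` terms, so its typical magnitude grows like `sqrt(dim)`, and normalizing by it lets one threshold serve every dim in the loop. A quick illustration of that scaling (my reading of the test's intent, not taken from the commit):

import math
import torch

for dim in [128, 1024, 4096]:
    A = torch.randn(1, dim)
    B = torch.randn(dim, dim)
    # mean |entry| of A @ B.t() grows ~sqrt(dim); this ratio stays roughly constant
    print(dim, (A @ B.t()).abs().mean().item() / math.sqrt(dim))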