Added more extensive gemv tests; blocksize guard for gemv.
commit ba51d95d43 (parent b8da4a165a)
bitsandbytes/autograd/_functions.py
@@ -3,6 +3,7 @@ import warnings
 from dataclasses import dataclass
 from functools import reduce # Required in Python 3
 from typing import Tuple, Optional, List
+from warnings import warn

 import torch

@@ -565,6 +566,11 @@ def matmul(
 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        return F.gemv_4bit(A, B.t(), out, state=quant_state)
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
+        if A.shape[-1] % blocksize != 0:
+            warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+            return MatMul4Bit.apply(A, B, out, bias, quant_state)
+        else:
+            return F.gemv_4bit(A, B.t(), out, state=quant_state)
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
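A minimal sketch of what the new guard means at call time, assuming bitsandbytes' default 4-bit blocksize of 64 and illustrative shapes (not part of the diff):

import math
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 4096  # a multiple of the blocksize (64)
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
B = torch.randn(4 * dim, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)
qB, state = F.quantize_4bit(B, quant_type='nf4')

# single token, no grad, hidden dim divisible by the blocksize
# -> takes the fast F.gemv_4bit inference path
out = bnb.matmul_4bit(A, qB.t(), state)

# if A.shape[-1] were not a multiple of the blocksize, the new guard would
# warn and fall back to the slower MatMul4Bit.apply path instead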
bitsandbytes/functional.py
@@ -1504,6 +1504,7 @@ def gemv_4bit(
             lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

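The surrounding dispatch picks a fused kernel per activation dtype; the visible branch routes fp32 activations to lib.cgemm_4bit_inference_naive_fp32 and raises for anything unsupported. A sketch of the fp32 path as exercised from Python (shapes illustrative):

import torch
import bitsandbytes.functional as F

A = torch.randn(1, 4096, dtype=torch.float32, device='cuda')
B = torch.randn(4096, 4096, dtype=torch.float32, device='cuda')
qB, state = F.quantize_4bit(B, quant_type='nf4')

# fp32 activations hit the cgemm_4bit_inference_naive_fp32 branch above;
# an unsupported dtype would raise NotImplementedError instead
out = F.gemv_4bit(A, qB.t(), state=state)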
csrc/kernels.cu
@@ -222,6 +222,7 @@ __device__ half dhDequantizeNF4(unsigned char val)

 __device__ float dDequantizeNF4(unsigned char val)
 {

   // the values for this tree was generated by test_normal_map_tree
   // in the file tests/test_functional.py
   if((val & 0b1000) == 8)
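The bit-test tree below this comment is equivalent to indexing a 16-entry table of NF4 values; a Python sketch of the same mapping (NF4_VALUES and dequantize_nf4 are illustrative names, values are the standard NF4 code book):

import torch

NF4_VALUES = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0])

def dequantize_nf4(codes):
    # codes: tensor of 4-bit indices in [0, 15]; the CUDA tree walks the
    # bits of `val` to reach the same value without a table load
    return NF4_VALUES[codes.long()]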
@@ -3526,10 +3527,9 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 {

   // per threadblock:
-  // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
-  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
+  // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
   // 4 warps -> 4 loads per iter
-  // 1x128 * 128x4 -> 1x4 outputs
+  // 1x32 * 32x4 -> 1x4 outputs per thread block
   typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
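A Python sketch of the tiling the updated comments describe: 4 warps per block, one output per warp, and each iteration consumes a 32-wide chunk of A against the matching weights (block_gemv is a hypothetical helper, not kernel code):

import numpy as np

def block_gemv(A, B, warps=4, chunk=32):
    # A: [K] activations, B: [N, K] weights; mirrors 1x32 * 32x4 -> 1x4 per block
    N, K = B.shape
    out = np.zeros(N, dtype=np.float32)
    for n0 in range(0, N, warps):             # one threadblock -> `warps` outputs
        for w in range(min(warps, N - n0)):   # one warp per output column
            acc = np.float32(0.0)             # float accumulator, as in the kernel
            for k0 in range(0, K, chunk):     # step through A in 32-wide chunks
                acc += A[k0:k0+chunk] @ B[n0+w, k0:k0+chunk]
            out[n0+w] = acc                   # stands in for the cub::WarpReduce sum
    return out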
@@ -3547,7 +3547,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

   for(int i = threadIdx.x; i < 16; i++)
     quant_map[i] = T(datatype[i]);

   __syncthreads();

   // A: [1, K]
@@ -3563,6 +3562,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
   {
     if((inner_idx_halved + num_values_8bit) < (K/2))
     {
       // this is the most important for performance considerations
       reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
     }
     else
@@ -3597,6 +3597,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

     if(inner_idx+num_values_4bit < K)
     {
       // this is also relatively important for performance
       if(BITS==16)
       {
         reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
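The int4 casts in the two hunks above turn a batch of packed values into a single 128-bit load, which is what the performance comments refer to. A rough numpy analogy, reading one buffer through a wider view (illustrative only):

import numpy as np

packed = np.arange(64, dtype=np.uint8)        # 64 bytes = 128 packed 4-bit values
# read the same bytes 16 at a time, like reinterpret_cast<int4*> in the kernel
wide = packed.view(np.uint32).reshape(-1, 4)  # each row models one int4 load
assert wide.shape == (4, 4) and wide.nbytes == packed.nbytes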
@@ -3618,6 +3619,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

     }
     else
       #pragma unroll
       for(int k = 0; k < num_values_4bit; k++)
         if(inner_idx + k < K)
           local_A[k] = A[inner_idx + k];
@@ -3625,6 +3627,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
           local_A[k] = T(0.0f);

     // accumulate in float; small performance hit for Ampere, but lower error for outputs
     #pragma unroll
     for(int k = 0; k < num_values_4bit; k++)
     {
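The zero-fill above handles a K that is not a multiple of the vector width: padding both operands of a dot product with zeros leaves the result unchanged, e.g.:

import numpy as np

K, width = 100, 32
pad = (-K) % width                       # 28 elements to reach a multiple of 32
a = np.random.randn(K).astype(np.float32)
w = np.random.randn(K).astype(np.float32)
a_pad = np.concatenate([a, np.zeros(pad, dtype=np.float32)])
w_pad = np.concatenate([w, np.zeros(pad, dtype=np.float32)])
assert np.isclose(a @ w, a_pad @ w_pad)  # the zero tail contributes nothing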
csrc/ops.cu
@@ -735,6 +735,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
   int num_blocks = (m+3)/4;

   kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
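num_blocks = (m+3)/4 is integer ceiling division: if, as the kernel comments above suggest, each 128-thread block (4 warps) produces 4 outputs, ceil(m/4) blocks cover all m outputs. In Python terms (num_blocks here is a sketch, not library code):

def num_blocks(m, outputs_per_block=4):
    # same result as the C expression (m + 3) / 4 under integer truncation
    return (m + outputs_per_block - 1) // outputs_per_block

assert num_blocks(1) == 1 and num_blocks(4) == 1 and num_blocks(5) == 2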
tests/test_functional.py
@@ -2262,7 +2262,7 @@ def test_fp4_quant(dtype):
     A2 = F.dequantize_fp4(qa, SA)

     err = (A1 - A2).abs().float()
-    relerr = (err/A1.abs().float()).mean()
+    relerr = (err/(A1.abs().float()+1e-8)).mean()
     idx = err > 1.0
     err = err.mean()
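The added 1e-8 keeps the relative error finite when an entry of A1 is exactly zero; a toy repro of the difference:

import torch

err = torch.tensor([0.1, 0.1])
A1 = torch.tensor([0.0, 2.0])
print((err / A1.abs().float()).mean())          # inf: divides by zero
print((err / (A1.abs().float() + 1e-8)).mean()) # large but finite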
@@ -2361,91 +2361,133 @@ def test_normal_map_tree():

 @pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
+@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_gemv_4bit(dtype, storage_type, double_quant):
-    print('')
-    for dim in [128, 256, 512, 1024, 2048, 4096]:
+def test_gemv_4bit(dtype, storage_type, double_quant, kind):
+    for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
     #for dim in [4*1024]:
     #for dim in [1*16]:
-        errs = []
-        relerrs = []
-        max_err = 0
-        max_relerr = 0
     #for dim in [1*128]:
+        errs1 = []
+        errs2 = []
+        errs3 = []
+        relerrs1 = []
+        relerrs2 = []
+        relerrs3 = []
+        max_errs1 = []
+        max_errs2 = []
+        max_errs3 = []

         for i in range(100):
-            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, dim, dtype=dtype, device='cuda')
-            #B = torch.randn(4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-
-            #print('')
-            #print(A)
-            #print(B.t())
-            #A[:, :-1] = 0
-            #B[:, :-1] = 0
-            #A.flatten()[:-1] = 0
-            #B.flatten()[:-1] = 0
+            if kind == 'fc1':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'fc2':
+                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'attn':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'attn_packed':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

             qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
             #F.dequantize_4bit(qB, state)

+            C3 = torch.matmul(A, B.t())
             C2 = F.gemv_4bit(A, qB.t(), state=state)
+            A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)

             #print(state)
             #print(qB)
+            err1 = (C1-C2).abs().float()
+            err2 = (C3-C2).abs().float()
+            err3 = (C3-C1).abs().float()

             #print('')
             #print(A)
             #print(B)
             #print('='*89)
             #print(C3)
+            mag1 = torch.abs(C1).float()+1e-5
+            mag2 = torch.abs(C3).float()+1e-5
+            mag3 = torch.abs(C3).float()+1e-5

             #print(C1.shape, C2.shape)
+            relerr1 = err1/mag1
+            relerr2 = err2/mag2
+            relerr3 = err3/mag3

-            # tensor cores are non-deterministic
-            # so we need to analyze errors around the mean
-            # to test our implementation
-            err = torch.abs(C1-C2).float()
-            mag = torch.abs(C1).float()+1e-5
-            relerr = err/mag
-            max_err = max(err.max(), max_err)
-            max_relerr = max(relerr.max(), max_relerr)
-            err = err.mean().item()
-            relerr = relerr.mean().item()
-            #print(err)
+            max_err1 = err1.max()
+            max_err2 = err2.max()
+            max_err3 = err3.max()

-            errs.append(err)
-            relerrs.append(relerr)
+            errs1.append(err1.mean().item())
+            errs2.append(err2.mean().item())
+            errs3.append(err3.mean().item())

+            relerrs1.append(relerr1.mean().item())
+            relerrs2.append(relerr2.mean().item())
+            relerrs3.append(relerr3.mean().item())

+            max_errs1.append(max_err1.item())
+            max_errs2.append(max_err2.item())
+            max_errs3.append(max_err3.item())

         c = int(C1.numel()*0.0014*(dim/256))+1

         c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
-        #print('')
-        #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
-        #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
-        #print(dim, (max_err.item(), max_relerr.item()))
-        print(C1.flatten()[-20:])
-        print(C2.flatten()[-20:])
-        #print(C1.flatten())
-        #print(C2.flatten())
-        #print(C3.flatten()[-20:])
-        print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
-        print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
+        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
+        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
+        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
+        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
+        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
+        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
+        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
+        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
+        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
+        absratio = err2/err3
+        relratio = relerr2/relerr3
+        maxratio = relerr2/relerr3

+        # for debugging if the tests fails
+        #
+        #print('='*80)
+        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
+        #print(C1.flatten()[-20:])
+        #print(C2.flatten()[-20:])
+        #print(f'inference vs training abs: {err1}')
+        #print(f'inference vs training rel: {relerr1}')
+        #print(f'inference vs training max: {maxerr1}')
+        #print(f'inference vs training vs torch err ratio abs: {absratio}')
+        #print(f'inference vs training vs torch err ratio rel: {relratio}')
+        #print(f'inference vs training vs torch err ratio max: {maxratio}')
         if dtype == torch.float16:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
+            if dim <= 512:
+                assert err1 < 7e-5
+                assert relerr1 < 0.0008
+            else:
+                assert err1 < 6e-5
+                assert relerr1 < 2e-4
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.005 and relratio > 0.995
+            assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.float32:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-8
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 1e-7
+            if dim <= 512:
+                assert err1 < 5e-8
+                assert relerr1 < 1e-6
+                assert maxerr1 < 1e-7
+            else:
+                assert err1 < 5e-8
+                assert relerr1 < 8e-6
+                assert maxerr1 < 1e-7
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.005 and relratio > 0.995
+            assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.bfloat16:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
+            if dim <= 512:
+                assert err1 < 5e-4
+                assert relerr1 < 0.007
+                assert maxerr1 < 0.015
+            else:
+                assert err1 < 2e-4
+                assert relerr1 < 0.002
+                assert maxerr1 < 0.0012
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.04 and relratio > 0.96
+            assert maxratio < 1.02 and maxratio > 0.98

 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
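What the new assertions establish, in short: the inference kernel's mean error against the torch.matmul reference (err2) must track the training path's error (err3) to within 0.5% (up to 4% for the bf16 relative ratio). A condensed sketch of that ratio check (check_ratio is a hypothetical helper):

def check_ratio(err_inference, err_training, tol=0.005):
    # absratio/relratio/maxratio style check: both paths should sit
    # equally far from the full-precision reference
    ratio = err_inference / err_training
    return (1 - tol) < ratio < (1 + tol)

assert check_ratio(1.000, 1.002)
assert not check_ratio(1.10, 1.00)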