Added more extensive gemv tests; blocksize guard for gemv.

This commit is contained in:
Tim Dettmers 2023-07-11 05:55:49 -07:00
parent b8da4a165a
commit ba51d95d43
5 changed files with 122 additions and 69 deletions

View File

@ -3,6 +3,7 @@ import warnings
from dataclasses import dataclass
from functools import reduce # Required in Python 3
from typing import Tuple, Optional, List
from warnings import warn
import torch
@ -565,6 +566,11 @@ def matmul(
def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False:
return F.gemv_4bit(A, B.t(), out, state=quant_state)
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
if A.shape[-1] % blocksize != 0:
warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
return MatMul4Bit.apply(A, B, out, bias, quant_state)
else:
return F.gemv_4bit(A, B.t(), out, state=quant_state)
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

View File

@ -1504,6 +1504,7 @@ def gemv_4bit(
lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
else:
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
else:
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

View File

@ -222,6 +222,7 @@ __device__ half dhDequantizeNF4(unsigned char val)
__device__ float dDequantizeNF4(unsigned char val)
{
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if((val & 0b1000) == 8)
@ -3526,10 +3527,9 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
{
// per threadblock:
// load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
// 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1x128 * 128x4 -> 1x4 outputs
// 1x32 * 32x4 -> 1x4 outputs per thread block
typedef cub::WarpReduce<float> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
@ -3547,7 +3547,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
for(int i = threadIdx.x; i < 16; i++)
quant_map[i] = T(datatype[i]);
__syncthreads();
// A: [1, K]
@ -3563,6 +3562,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
{
if((inner_idx_halved + num_values_8bit) < (K/2))
{
// this is the most important for performance considerations
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
}
else
@ -3597,6 +3597,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
if(inner_idx+num_values_4bit < K)
{
// this is also relatively important for performance
if(BITS==16)
{
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
@ -3618,6 +3619,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
}
else
#pragma unroll
for(int k = 0; k < num_values_4bit; k++)
if(inner_idx + k < K)
local_A[k] = A[inner_idx + k];
@ -3625,6 +3627,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
local_A[k] = T(0.0f);
// accumulate in float; small performance hit for Ampere, but lower error for outputs
#pragma unroll
for(int k = 0; k < num_values_4bit; k++)
{

View File

@ -735,6 +735,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
int num_blocks = (m+3)/4;
kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)

View File

@ -2262,7 +2262,7 @@ def test_fp4_quant(dtype):
A2 = F.dequantize_fp4(qa, SA)
err = (A1 - A2).abs().float()
relerr = (err/A1.abs().float()).mean()
relerr = (err/(A1.abs().float()+1e-8)).mean()
idx = err > 1.0
err = err.mean()
@ -2361,91 +2361,133 @@ def test_normal_map_tree():
@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant):
print('')
for dim in [128, 256, 512, 1024, 2048, 4096]:
def test_gemv_4bit(dtype, storage_type, double_quant, kind):
for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
#for dim in [4*1024]:
#for dim in [1*16]:
errs = []
relerrs = []
max_err = 0
max_relerr = 0
#for dim in [1*128]:
errs1 = []
errs2 = []
errs3 = []
relerrs1 = []
relerrs2 = []
relerrs3 = []
max_errs1 = []
max_errs2 = []
max_errs3 = []
for i in range(100):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
A = torch.randn(1, dim, dtype=dtype, device='cuda')
#B = torch.randn(4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
#print('')
#print(A)
#print(B.t())
#A[:, :-1] = 0
#B[:, :-1] = 0
#A.flatten()[:-1] = 0
#B.flatten()[:-1] = 0
if kind == 'fc1':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'fc2':
A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'attn':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
elif kind == 'attn_packed':
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
#F.dequantize_4bit(qB, state)
C3 = torch.matmul(A, B.t())
C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)
#print(state)
#print(qB)
err1 = (C1-C2).abs().float()
err2 = (C3-C2).abs().float()
err3 = (C3-C1).abs().float()
#print('')
#print(A)
#print(B)
#print('='*89)
#print(C3)
mag1 = torch.abs(C1).float()+1e-5
mag2 = torch.abs(C3).float()+1e-5
mag3 = torch.abs(C3).float()+1e-5
#print(C1.shape, C2.shape)
relerr1 = err1/mag1
relerr2 = err2/mag2
relerr3 = err3/mag3
# tensor cores are non-deterministic
# so we need to analyze errors around the mean
# to test our implementation
err = torch.abs(C1-C2).float()
mag = torch.abs(C1).float()+1e-5
relerr = err/mag
max_err = max(err.max(), max_err)
max_relerr = max(relerr.max(), max_relerr)
err = err.mean().item()
relerr = relerr.mean().item()
#print(err)
max_err1 = err1.max()
max_err2 = err2.max()
max_err3 = err3.max()
errs.append(err)
relerrs.append(relerr)
errs1.append(err1.mean().item())
errs2.append(err2.mean().item())
errs3.append(err3.mean().item())
relerrs1.append(relerr1.mean().item())
relerrs2.append(relerr2.mean().item())
relerrs3.append(relerr3.mean().item())
max_errs1.append(max_err1.item())
max_errs2.append(max_err2.item())
max_errs3.append(max_err3.item())
c = int(C1.numel()*0.0014*(dim/256))+1
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
#print('')
#print(dim, sum(errs)/len(errs)/math.sqrt(dim))
#print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
#print(dim, (max_err.item(), max_relerr.item()))
print(C1.flatten()[-20:])
print(C2.flatten()[-20:])
#print(C1.flatten())
#print(C2.flatten())
#print(C3.flatten()[-20:])
print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
absratio = err2/err3
relratio = relerr2/relerr3
maxratio = relerr2/relerr3
# for debugging if the tests fails
#
#print('='*80)
#print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
#print(C1.flatten()[-20:])
#print(C2.flatten()[-20:])
#print(f'inference vs training abs: {err1}')
#print(f'inference vs training rel: {relerr1}')
#print(f'inference vs training max: {maxerr1}')
#print(f'inference vs training vs torch err ratio abs: {absratio}')
#print(f'inference vs training vs torch err ratio rel: {relratio}')
#print(f'inference vs training vs torch err ratio max: {maxratio}')
if dtype == torch.float16:
assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
if dim <= 512:
assert err1 < 7e-5
assert relerr1 < 0.0008
else:
assert err1 < 6e-5
assert relerr1 < 2e-4
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.float32:
assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-8
assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 1e-7
if dim <= 512:
assert err1 < 5e-8
assert relerr1 < 1e-6
assert maxerr1 < 1e-7
else:
assert err1 < 5e-8
assert relerr1 < 8e-6
assert maxerr1 < 1e-7
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.bfloat16:
assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
if dim <= 512:
assert err1 < 5e-4
assert relerr1 < 0.007
assert maxerr1 < 0.015
else:
assert err1 < 2e-4
assert relerr1 < 0.002
assert maxerr1 < 0.0012
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.04 and relratio > 0.96
assert maxratio < 1.02 and maxratio > 0.98
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():