Added arbitrary data types; fixed a bug for small matrices.
This commit is contained in:
parent eefbf60270
commit 4b88d69de7
@@ -562,6 +562,6 @@ def matmul(
 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        return F.cutlass3_gemm(A, B.t(), out, state=quant_state)
+        return F.gemv_4bit(A, B.t(), out, state=quant_state)
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
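The renamed gemv_4bit fast path only triggers for a single-row activation with gradients disabled; everything else still routes through MatMul4Bit. A minimal usage sketch (shapes and the NF4 weight here are illustrative assumptions, not part of this diff):

# Single-token inference takes the gemv_4bit path; batched input falls back to MatMul4Bit.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
qW, quant_state = F.quantize_nf4(W)                              # packed 4-bit weight + state

x = torch.randn(1, 4096, dtype=torch.float16, device='cuda')     # numel == shape[-1] -> gemv path
y = bnb.matmul_4bit(x, qW.t(), quant_state)

xb = torch.randn(8, 4096, dtype=torch.float16, device='cuda')    # batched -> MatMul4Bit.apply path
yb = bnb.matmul_4bit(xb, qW.t(), quant_state)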
@@ -240,17 +240,19 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
         v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
         v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
+        v = v1 + v2 + v3
     else:
         v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
         v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
+        v = v1 + v2 + v3

-    v = v1 + v2 + v3

     values = torch.Tensor(v)
     values = values.sort().values
     values /= values.max()

+    assert values.numel() == 256

     return values


 def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
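The construction above takes Gaussian quantiles on either side of zero, pads the 256-entry table with zeros, sorts, and normalizes to [-1, 1]. A condensed, self-contained restatement of the use_extra_value branch, for reference only:

# Condensed restatement of create_normal_map's asymmetric (use_extra_value=True) construction.
import torch
from scipy.stats import norm

offset = 0.9677083
pos = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()      # 8 positive-side quantiles
neg = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()   # 7 negative-side quantiles
values = torch.Tensor(pos + [0]*(256-15) + neg).sort().values
values /= values.max()                                            # codebook normalized to [-1, 1]
assert values.numel() == 256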
@@ -710,6 +712,47 @@ def dequantize_blockwise(

     return out

+def get_4bit_type(typename, device=None, blocksize=64):
+    if device is None: device = 'cuda'
+    data = None
+    if typename == 'nf4':
+        data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
+                -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
+                0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
+                0.7229568362236023, 1.0]
+    elif typename == 'fp4':
+        # 0b000 = 0
+        # 0b001 = 0.0625
+        # 0b010 = 8
+        # 0b011 = 12
+        # 0b100 = 4
+        # 0b101 = 6
+        # 0b110 = 2
+        # 0b111 = 3
+        data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
+    elif typename == 'int4':
+        data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
+    elif typename == 'af4':
+        # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
+        # https://arxiv.org/abs/2306.06965
+        if blocksize == 64:
+            data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478,
+                    -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666,
+                    0.42563882, 0.55496234, 0.72424863, 1.][::-1]
+        else:
+            raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.')
+
+    if data is None:
+        raise NotImplementedError(f'Typename {typename} not supported')
+
+    data = Tensor(data)
+    data /= data.abs().max()
+    assert data.numel() == 16
+
+    return data.to(device)
+
+
 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
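With the dequantization table materialized as a tensor, the same code path can serve the NF4, FP4, INT4 and AF4 codebooks instead of hard-coding one of them in the kernel. A quick illustrative use of the new helper:

# Fetch the 16-entry codebooks that the 4-bit kernels now consume as plain tensors.
import bitsandbytes.functional as F

nf4 = F.get_4bit_type('nf4', device='cuda')
fp4 = F.get_4bit_type('fp4', device='cuda')
af4 = F.get_4bit_type('af4', device='cuda', blocksize=64)
print(nf4.numel(), nf4.min().item(), nf4.max().item())   # 16 values normalized to [-1, 1]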
@@ -783,6 +826,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

+    datatype = get_4bit_type(quant_type, device=A.device)
+
     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
@@ -790,9 +835,9 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
+        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
+        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]

     return out, state
@@ -839,7 +884,7 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state


     if compressed_stats is not None:
@@ -1408,13 +1453,14 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8

     return sout

-def cutlass3_gemm(
+def gemv_4bit(
     A: Tensor,
     B: Tensor,
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None
+    state=None,
+    storage_type='nf4'
 ):
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
@@ -1491,8 +1537,6 @@ def cutlass3_gemm(
     ldb = sA[2]
     ldc = m

-    ptr = CUBLAS_Context.get_instance().get_context(A.device)
-
     # B^T @ A^T = C^T
     # [km, nk -> mn]
     #lda = ldb = ldc = 1
@@ -1514,15 +1558,11 @@ def cutlass3_gemm(

     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
-    elif A.dtype == torch.float32:
-        lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
-    elif A.dtype == torch.float16:
-        lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
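quant_state now carries the codebook from get_4bit_type as its last element, which is what get_ptr(state[-1]) hands to the kernel next to the per-block absmax. A short sketch of that layout (compress_statistics left at its default; the weight shape is an illustrative assumption):

# The 7-element quant_state produced by quantize_4bit; the trailing entry is the codebook tensor.
import torch
import bitsandbytes.functional as F

W = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
qW, quant_state = F.quantize_nf4(W)
absmax, shape, dtype, blocksize, compressed_stats, quant_type, datatype = quant_state
assert datatype.numel() == 16      # this is what gemv_4bit passes as get_ptr(state[-1])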
@@ -3520,7 +3520,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
 }

 #define num_values_4bit 32
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {

   // per threadblock:
@@ -3568,7 +3568,9 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
     {
       #pragma unroll
       for(int j = 0; j < (num_values_8bit); j++)
-        if((inner_idx/2) + j < K)
+        if((inner_idx_halved) + j < K)
           local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
+        else
+          local_B_4bit[j] = 0b01110111;
     }
   }
@@ -3578,6 +3580,9 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
     {
       local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
       local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+
+      //if(threadIdx.x == 0)
+      //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax);
     }

     if(inner_idx+num_values_4bit)
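Each packed byte holds two 4-bit indices, decoded through the shared quant_map (now loaded from the new datatype argument) and scaled by the block's absmax; the padding byte 0b01110111 sets both nibbles to 7, which indexes the zero entry of the NF4 table, so the out-of-range tail of a small matrix contributes nothing. The per-byte logic, restated in Python for intuition only:

# Reference-only restatement of the kernel's per-byte dequantization.
def dequant_byte(b: int, quant_map, local_absmax: float):
    hi = quant_map[b >> 4] * local_absmax      # first packed value (high nibble)
    lo = quant_map[b & 0x0F] * local_absmax    # second packed value (low nibble)
    return hi, lo

# Padding byte 0b01110111 -> both nibbles are 7 -> quant_map[7] == 0.0 for NF4.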
@@ -3773,8 +3778,8 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

-template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);

 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
@@ -125,7 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *

 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
@@ -729,12 +729,12 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }

-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {

   int num_blocks = (m+3)/4;

-  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
 }

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
@@ -757,8 +757,8 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
 template void func<float, _MUL>(float *A, float *B, float value, long n);

 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
 //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
@@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows

 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
@@ -28,11 +28,11 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
 void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
 { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

-void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

-void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

 #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
 void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
@@ -394,11 +394,11 @@ extern "C"
	CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
	CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)

-	void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
-	{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+	void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
+	{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

-	void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-	{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+	void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
+	{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

 #endif
 void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
@@ -1816,7 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden):

     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
-    F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
+    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)

     # warmup
     for i in range(iters):
@@ -1849,7 +1849,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         t0 = time.time()
         for i in range(iters):
             #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-            F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
+            F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
         torch.cuda.synchronize()
         print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
@@ -2351,76 +2351,14 @@ def test_normal_map_tree():
     print(pivots)


-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
-def test_cutlass3_gemm(dtype):
-    debug = True
-    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
-    #for dim in [4096, 5120, 6656, 8192]:
-    for dim in [4096]:
-    #for dim in [128+1]:
-        errs = []
-        relerrs = []
-        max_err = 0
-        max_relerr = 0
-        for i in range(100):
-            A = torch.randn(1, dim, dtype=dtype, device='cuda')
-            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
-            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-
-            #print('')
-            #print(A)
-            #print(B.t())
-            #A[:, :-1] = 0
-            #B[:, :-1] = 0
-
-
-            C1 = torch.matmul(A, B.t())
-            C2 = F.cutlass3_gemm(A, B.t())
-
-            # tensor cores are non-deterministic
-            # so we need to analyze errors around the mean
-            # to test our implementation
-            err = torch.abs(C1-C2)
-            mag = torch.abs(C1)+1e-8
-            relerr = err/mag
-            max_err = max(err.max(), max_err)
-            max_relerr = max(relerr.max(), max_relerr)
-            err = err.mean().item()
-            relerr = relerr.mean().item()
-
-            errs.append(err)
-            relerrs.append(relerr)
-
-            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-            #    print('')
-            #    print(i, err, relerr)
-            #    print(A.flatten()[-6:])
-            #    print(B.flatten()[-6:])
-            #    out = A.flatten()[-6:]*B.flatten()[-6:]
-            #    print(out)
-            #    print(out[:-1].sum())
-            #    print('='*80)
-            #    print(C1.flatten()[-6:])
-            #    print(C2.flatten()[-6:])
-            #    #assert False, 'ERROR'
-
-        c = int(C1.numel()*0.0014*(dim/256))+1
-
-        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
-        #print(c/math.sqrt(dim))
-        print('')
-        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
-        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
-        print(dim, (max_err.item(), max_relerr.item()))
-
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemm_4bit(dtype):
+def test_gemv_4bit(dtype):
     print('')
-    #for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
-    for dim in [4*1024]:
+    for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
+    #for dim in [4*1024]:
     #for dim in [1*16]:
         errs = []
         relerrs = []
         max_err = 0
@@ -2446,9 +2384,10 @@ def test_gemm_4bit(dtype):
             qB, state = F.quantize_nf4(B)
             F.dequantize_nf4(qB, state)

-            #C2 = bnb.matmul_4bit(A, qB.t(), state)
-            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-            C1 = torch.matmul(A, B.t())
+            C2 = F.gemv_4bit(A, qB.t(), state=state)
+            C3 = torch.matmul(A, B.t())
+            A.requires_grad = True
+            C1 = bnb.matmul_4bit(A, qB.t(), state)

             #print(state)
             #print(qB)
@@ -2457,8 +2396,7 @@ def test_gemm_4bit(dtype):
             #print(A)
             #print(B)
             #print('='*89)
-            #print(C1)
-            #print(C2)
+            #print(C3.flatten()[-20:])
             #print(C3)

             #print(C1.shape, C2.shape)
@@ -2485,10 +2423,16 @@ def test_gemm_4bit(dtype):
         #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
         #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
         #print(dim, (max_err.item(), max_relerr.item()))
+        print(C1.flatten()[-20:])
+        print(C2.flatten()[-20:])
+        print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
+        print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
-        assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
-        assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
+        if dtype == torch.float16:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
+        else:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003

 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
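The new dtype-dependent thresholds bound the mean absolute and relative error of the gemv path against the reference matmul, scaled by 1/sqrt(dim). A small self-contained sketch of that metric, with C1 and C2 standing in for any two same-shape results:

# Error metric mirrored from the test: mean |C1 - C2| and mean relative error, scaled by 1/sqrt(dim).
import math
import torch

def gemv_error_stats(C1: torch.Tensor, C2: torch.Tensor, dim: int):
    err = torch.abs(C1 - C2)
    relerr = err / (torch.abs(C1) + 1e-8)
    return err.mean().item() / math.sqrt(dim), relerr.mean().item() / math.sqrt(dim)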