Added bfloat16 quantizations and tests.
This commit is contained in:
parent
dfe6900b94
commit
02fd80cb81
|
@ -561,4 +561,7 @@ def matmul(
|
|||
|
||||
def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
|
||||
assert quant_state is not None
|
||||
return MatMul4Bit.apply(A, B, out, bias, quant_state)
|
||||
if A.numel() == A.shape[-1] and A.requires_grad == False:
|
||||
return F.cutlass3_gemm(A, B.t(), out, state=quant_state)
|
||||
else:
|
||||
return MatMul4Bit.apply(A, B, out, bias, quant_state)
|
||||
|
|
|
@ -617,6 +617,8 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
|
|||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.bfloat16:
|
||||
lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
@ -629,11 +631,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
|
|||
offset = absmax.mean()
|
||||
absmax -= offset
|
||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
|
||||
state = [qabsmax, code, blocksize, nested, offset, state2]
|
||||
state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
|
||||
else:
|
||||
state = [absmax, code, blocksize, nested, None, None]
|
||||
|
||||
|
||||
state = [absmax, code, blocksize, nested, A.dtype, None, None]
|
||||
|
||||
return out, state
|
||||
|
||||
|
@ -678,18 +678,16 @@ def dequantize_blockwise(
|
|||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
code = name2qmap["dynamic"]
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.float32)
|
||||
|
||||
if quant_state is None:
|
||||
quant_state = (absmax, code, blocksize)
|
||||
assert absmax is not None and out is not None
|
||||
else:
|
||||
absmax, code, blocksize, nested, offset, state2 = quant_state
|
||||
if nested:
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
|
||||
|
||||
absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
|
||||
if nested:
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
|
||||
if out is None:
|
||||
out = torch.empty(A.shape, dtype=dtype, device=A.device)
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
device = pre_call(A.device)
|
||||
|
@ -701,6 +699,8 @@ def dequantize_blockwise(
|
|||
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
elif out.dtype == torch.float16:
|
||||
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
elif out.dtype == torch.bfloat16:
|
||||
lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
@ -774,6 +774,11 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
|
|||
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
elif A.dtype == torch.bfloat16:
|
||||
if quant_type == 'fp4':
|
||||
lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
@ -860,6 +865,11 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
|
|||
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
elif out.dtype == torch.bfloat16:
|
||||
if quant_type == 'fp4':
|
||||
lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
@ -1503,7 +1513,12 @@ def cutlass3_gemm(
|
|||
ldc = ct.c_int32(ldc)
|
||||
|
||||
if B.dtype == torch.uint8:
|
||||
lib.cgemm_4bit_inference_naive(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
if A.dtype == torch.float16:
|
||||
lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
elif A.dtype == torch.bfloat16:
|
||||
lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
else:
|
||||
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
|
||||
elif A.dtype == torch.float32:
|
||||
lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
||||
elif A.dtype == torch.float16:
|
||||
|
@ -1515,7 +1530,6 @@ def cutlass3_gemm(
|
|||
|
||||
|
||||
|
||||
|
||||
def igemm(
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
|
|
|
@ -3540,7 +3540,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
|
|||
unsigned char local_B_4bit[num_values_8bit];
|
||||
T local_B[num_values_4bit];
|
||||
T local_A[num_values_4bit];
|
||||
__shared__ half quant_map[16*THREADS];
|
||||
__shared__ T quant_map[16*THREADS];
|
||||
|
||||
for(int i = 0; i < 16; i++)
|
||||
quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
|
||||
|
@ -3769,11 +3769,8 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
|
|||
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 32>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||
|
@ -3929,6 +3926,20 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
|
|||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
|
||||
|
@ -3937,13 +3948,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
|
|||
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
|
||||
|
@ -3951,13 +3955,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
|
|||
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
|
||||
|
@ -3966,12 +3963,38 @@ MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
|
|||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
|
||||
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)
|
||||
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||
|
||||
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
|
||||
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
|
||||
|
|
33
csrc/ops.cu
33
csrc/ops.cu
|
@ -733,20 +733,8 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
|
|||
{
|
||||
|
||||
int num_blocks = (m+3)/4;
|
||||
//int num_blocks = m;
|
||||
|
||||
cout << num_blocks << endl;
|
||||
//cout << lda << endl;
|
||||
//cout << ldb << endl;
|
||||
//cout << ldc << endl;
|
||||
|
||||
//cout << m << endl;
|
||||
//cout << n << endl;
|
||||
//cout << k << endl;
|
||||
kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
//kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
//kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
}
|
||||
|
||||
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
|
||||
|
@ -770,6 +758,7 @@ template void func<float, _MUL>(float *A, float *B, float value, long n);
|
|||
|
||||
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
|
||||
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
|
||||
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
|
||||
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||
|
@ -796,19 +785,27 @@ template void estimateQuantiles(half *A, float *code, float offset, int n);
|
|||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
|
||||
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||
|
||||
#define MAKE_optimizer32bit(name, gtype) \
|
||||
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
||||
|
|
|
@ -28,9 +28,12 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
|
|||
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
|
||||
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
|
||||
|
||||
|
@ -106,19 +109,29 @@ void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){
|
|||
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
||||
void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
|
||||
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
|
||||
|
||||
void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); }
|
||||
|
||||
|
||||
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
||||
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
|
||||
|
@ -177,21 +190,31 @@ extern "C"
|
|||
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
||||
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
||||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
#define MAKE_CFUNC32(name, gtype, gbits) \
|
||||
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||
|
@ -348,9 +371,6 @@ extern "C"
|
|||
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void cgemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void *cget_managed_ptr(size_t bytes)
|
||||
{
|
||||
void *ptr;
|
||||
|
@ -374,6 +394,12 @@ extern "C"
|
|||
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
|
||||
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
|
||||
|
||||
void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
#endif
|
||||
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
|
||||
|
|
|
@ -154,34 +154,36 @@ def test_dynamic_quantization():
|
|||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
|
||||
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
||||
def test_dynamic_blockwise_quantization(nested, blocksize):
|
||||
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
|
||||
#print('')
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diff = torch.abs(A1 - A2).float()
|
||||
reldiff = diff / torch.abs(A1.float() + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
|
||||
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
|
||||
assert abserr < 0.011
|
||||
assert relerr < 0.018
|
||||
#print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
|
||||
#print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
assert A2.dtype == dtype
|
||||
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diff = torch.abs(A1 - A2).float()
|
||||
reldiff = diff / torch.abs(A1.float() + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
||||
|
@ -189,6 +191,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
|
|||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.0035
|
||||
assert relerr < 0.015
|
||||
assert A2.dtype == dtype
|
||||
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
|
||||
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
@ -1773,8 +1776,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
|
|||
print("partial matmul", time.time() - t0)
|
||||
|
||||
|
||||
batch_size = 32
|
||||
seqdim = 512+256
|
||||
batch_size = 1
|
||||
seqdim = 1
|
||||
values = []
|
||||
#values.append((batch_size, seqdim, 768, 4 * 768))
|
||||
#values.append((batch_size, seqdim, 1024, 4*1024))
|
||||
|
@ -1800,7 +1803,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
B_fp4, state = F.quantize_fp4(B)
|
||||
B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
|
||||
|
||||
B_nf4, state_nf4= F.quantize_nf4(B)
|
||||
B_nf4, state_nf4 = F.quantize_nf4(B)
|
||||
|
||||
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
|
||||
linear8bit.eval()
|
||||
|
@ -1813,6 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
|
||||
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
|
||||
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
|
||||
F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
|
||||
|
||||
# warmup
|
||||
for i in range(iters):
|
||||
|
@ -1844,7 +1848,8 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
|
||||
#bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
|
||||
F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
|
||||
torch.cuda.synchronize()
|
||||
print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||
|
||||
|
@ -2221,7 +2226,8 @@ def test_bench_dequantization():
|
|||
|
||||
|
||||
|
||||
def test_fp4_quant():
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||
def test_fp4_quant(dtype):
|
||||
vals = list(product([0, 1], repeat=4))
|
||||
|
||||
code = {}
|
||||
|
@ -2243,7 +2249,7 @@ def test_fp4_quant():
|
|||
result = sign*exp*frac
|
||||
code[idx] = result
|
||||
|
||||
A1 = torch.randn(1024, 1024, device='cuda').half()
|
||||
A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
|
||||
qa, SA = F.quantize_fp4(A1, blocksize=64)
|
||||
A2 = F.dequantize_fp4(qa, SA)
|
||||
|
||||
|
@ -2252,7 +2258,7 @@ def test_fp4_quant():
|
|||
idx = err > 1.0
|
||||
err = err.mean()
|
||||
|
||||
|
||||
assert A2.dtype == dtype
|
||||
assert err.item() < 0.1
|
||||
assert relerr.item() < 0.28
|
||||
|
||||
|
@ -2409,20 +2415,16 @@ def test_cutlass3_gemm(dtype):
|
|||
print(dim, (max_err.item(), max_relerr.item()))
|
||||
|
||||
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
|
||||
#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
|
||||
def test_gemm_4bit(dtype):
|
||||
print('')
|
||||
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
#for dim in [4096, 5120, 6656, 8192]:
|
||||
#for dim in [32]:
|
||||
for dim in [2*4096]:
|
||||
#for dim in [5120]:
|
||||
#for dim in [6656]:
|
||||
#for dim in [4]:
|
||||
for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
errs = []
|
||||
relerrs = []
|
||||
max_err = 0
|
||||
max_relerr = 0
|
||||
|
||||
for i in range(100):
|
||||
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
|
||||
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
|
||||
|
@ -2443,14 +2445,13 @@ def test_gemm_4bit(dtype):
|
|||
qB, state = F.quantize_nf4(B)
|
||||
F.dequantize_nf4(qB, state)
|
||||
|
||||
#C3 = torch.matmul(A, B.t())
|
||||
#C2 = bnb.matmul_4bit(A, qB.t(), state)
|
||||
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
|
||||
C1 = bnb.matmul_4bit(A, qB.t(), state)
|
||||
C1 = torch.matmul(A, B.t())
|
||||
|
||||
#print(state)
|
||||
#print(qB)
|
||||
|
||||
|
||||
#print('')
|
||||
#print(A)
|
||||
#print(B)
|
||||
|
@ -2464,8 +2465,8 @@ def test_gemm_4bit(dtype):
|
|||
# tensor cores are non-deterministic
|
||||
# so we need to analyze errors around the mean
|
||||
# to test our implementation
|
||||
err = torch.abs(C1-C2)
|
||||
mag = torch.abs(C1)+1e-8
|
||||
err = torch.abs(C1-C2).float()
|
||||
mag = torch.abs(C1).float()+1e-5
|
||||
relerr = err/mag
|
||||
max_err = max(err.max(), max_err)
|
||||
max_relerr = max(relerr.max(), max_relerr)
|
||||
|
@ -2476,27 +2477,17 @@ def test_gemm_4bit(dtype):
|
|||
errs.append(err)
|
||||
relerrs.append(relerr)
|
||||
|
||||
if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
|
||||
print('')
|
||||
print(i, err, relerr)
|
||||
#print(A.flatten()[-6:])
|
||||
#print(B.flatten()[-6:])
|
||||
#out = A.flatten()[-6:]*B.flatten()[-6:]
|
||||
#print(out)
|
||||
#print(out[:-1].sum())
|
||||
print('='*80)
|
||||
#print(C1.flatten()[-6:])
|
||||
#print(C2.flatten()[-6:])
|
||||
#assert False, 'ERROR'
|
||||
|
||||
c = int(C1.numel()*0.0014*(dim/256))+1
|
||||
|
||||
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
|
||||
print(c/math.sqrt(dim))
|
||||
print('')
|
||||
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
||||
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
|
||||
print(dim, (max_err.item(), max_relerr.item()))
|
||||
#print('')
|
||||
#print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
||||
#print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
|
||||
#print(dim, (max_err.item(), max_relerr.item()))
|
||||
#print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
|
||||
#print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
|
||||
assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
|
||||
assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
|
||||
|
||||
@pytest.mark.skip("Row scale has some bugs for ampere")
|
||||
def test_managed():
|
||||
|
|
Loading…
Reference in New Issue
Block a user