Added bfloat16 quantizations and tests.
This commit is contained in:
parent
dfe6900b94
commit
02fd80cb81
|
@ -561,4 +561,7 @@ def matmul(
|
||||||
|
|
||||||
def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
|
def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
|
||||||
assert quant_state is not None
|
assert quant_state is not None
|
||||||
return MatMul4Bit.apply(A, B, out, bias, quant_state)
|
if A.numel() == A.shape[-1] and A.requires_grad == False:
|
||||||
|
return F.cutlass3_gemm(A, B.t(), out, state=quant_state)
|
||||||
|
else:
|
||||||
|
return MatMul4Bit.apply(A, B, out, bias, quant_state)
|
||||||
|
|
|
@ -617,6 +617,8 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
|
||||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||||
elif A.dtype == torch.float16:
|
elif A.dtype == torch.float16:
|
||||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||||
|
elif A.dtype == torch.bfloat16:
|
||||||
|
lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||||
post_call(A.device)
|
post_call(A.device)
|
||||||
|
@ -629,11 +631,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
|
||||||
offset = absmax.mean()
|
offset = absmax.mean()
|
||||||
absmax -= offset
|
absmax -= offset
|
||||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
|
qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
|
||||||
state = [qabsmax, code, blocksize, nested, offset, state2]
|
state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
|
||||||
else:
|
else:
|
||||||
state = [absmax, code, blocksize, nested, None, None]
|
state = [absmax, code, blocksize, nested, A.dtype, None, None]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return out, state
|
return out, state
|
||||||
|
|
||||||
|
@ -678,18 +678,16 @@ def dequantize_blockwise(
|
||||||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||||
code = name2qmap["dynamic"]
|
code = name2qmap["dynamic"]
|
||||||
|
|
||||||
if out is None:
|
|
||||||
out = torch.zeros_like(A, dtype=torch.float32)
|
|
||||||
|
|
||||||
if quant_state is None:
|
if quant_state is None:
|
||||||
quant_state = (absmax, code, blocksize)
|
quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
|
||||||
assert absmax is not None and out is not None
|
|
||||||
else:
|
|
||||||
absmax, code, blocksize, nested, offset, state2 = quant_state
|
|
||||||
if nested:
|
|
||||||
absmax = dequantize_blockwise(absmax, state2)
|
|
||||||
absmax += offset
|
|
||||||
|
|
||||||
|
absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
|
||||||
|
if nested:
|
||||||
|
absmax = dequantize_blockwise(absmax, state2)
|
||||||
|
absmax += offset
|
||||||
|
|
||||||
|
if out is None:
|
||||||
|
out = torch.empty(A.shape, dtype=dtype, device=A.device)
|
||||||
|
|
||||||
if A.device.type != 'cpu':
|
if A.device.type != 'cpu':
|
||||||
device = pre_call(A.device)
|
device = pre_call(A.device)
|
||||||
|
@ -701,6 +699,8 @@ def dequantize_blockwise(
|
||||||
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||||
elif out.dtype == torch.float16:
|
elif out.dtype == torch.float16:
|
||||||
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||||
|
elif out.dtype == torch.bfloat16:
|
||||||
|
lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||||
post_call(A.device)
|
post_call(A.device)
|
||||||
|
@ -774,6 +774,11 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
|
||||||
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||||
else:
|
else:
|
||||||
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||||
|
elif A.dtype == torch.bfloat16:
|
||||||
|
if quant_type == 'fp4':
|
||||||
|
lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||||
|
else:
|
||||||
|
lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||||
post_call(A.device)
|
post_call(A.device)
|
||||||
|
@ -860,6 +865,11 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
|
||||||
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||||
else:
|
else:
|
||||||
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||||
|
elif out.dtype == torch.bfloat16:
|
||||||
|
if quant_type == 'fp4':
|
||||||
|
lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||||
|
else:
|
||||||
|
lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||||
post_call(A.device)
|
post_call(A.device)
|
||||||
|
@ -1503,7 +1513,12 @@ def cutlass3_gemm(
|
||||||
ldc = ct.c_int32(ldc)
|
ldc = ct.c_int32(ldc)
|
||||||
|
|
||||||
if B.dtype == torch.uint8:
|
if B.dtype == torch.uint8:
|
||||||
lib.cgemm_4bit_inference_naive(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
if A.dtype == torch.float16:
|
||||||
|
lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||||
|
elif A.dtype == torch.bfloat16:
|
||||||
|
lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
|
||||||
elif A.dtype == torch.float32:
|
elif A.dtype == torch.float32:
|
||||||
lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
||||||
elif A.dtype == torch.float16:
|
elif A.dtype == torch.float16:
|
||||||
|
@ -1515,7 +1530,6 @@ def cutlass3_gemm(
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def igemm(
|
def igemm(
|
||||||
A: Tensor,
|
A: Tensor,
|
||||||
B: Tensor,
|
B: Tensor,
|
||||||
|
|
|
@ -3540,7 +3540,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
|
||||||
unsigned char local_B_4bit[num_values_8bit];
|
unsigned char local_B_4bit[num_values_8bit];
|
||||||
T local_B[num_values_4bit];
|
T local_B[num_values_4bit];
|
||||||
T local_A[num_values_4bit];
|
T local_A[num_values_4bit];
|
||||||
__shared__ half quant_map[16*THREADS];
|
__shared__ T quant_map[16*THREADS];
|
||||||
|
|
||||||
for(int i = 0; i < 16; i++)
|
for(int i = 0; i < 16; i++)
|
||||||
quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
|
quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
|
||||||
|
@ -3769,11 +3769,8 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
|
||||||
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||||
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||||
|
|
||||||
template __global__ void kgemm_4bit_inference_naive<half, 32>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
|
||||||
template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
|
||||||
template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||||
template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
|
||||||
template __global__ void kgemm_4bit_inference_naive<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
|
||||||
|
|
||||||
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||||
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||||
|
@ -3929,6 +3926,20 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
|
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
|
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
|
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
|
||||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
|
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
|
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
|
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
|
||||||
|
@ -3937,13 +3948,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
|
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
|
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
|
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
|
||||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
|
|
||||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
|
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
|
||||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
|
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
|
||||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
|
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
|
||||||
|
@ -3951,13 +3955,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
|
||||||
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
|
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
|
||||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
|
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
|
||||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
|
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
|
||||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
|
|
||||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
|
|
||||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
|
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
|
||||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
|
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
|
||||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
|
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
|
||||||
|
@ -3966,12 +3963,38 @@ MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
|
||||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
|
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
|
||||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
|
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
|
||||||
|
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
|
||||||
|
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)
|
||||||
|
|
||||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
|
||||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
|
||||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||||
|
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||||
|
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||||
|
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||||
|
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||||
|
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||||
|
|
||||||
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
|
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
|
||||||
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
|
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
|
||||||
|
|
33
csrc/ops.cu
33
csrc/ops.cu
|
@ -733,20 +733,8 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
|
||||||
{
|
{
|
||||||
|
|
||||||
int num_blocks = (m+3)/4;
|
int num_blocks = (m+3)/4;
|
||||||
//int num_blocks = m;
|
|
||||||
|
|
||||||
cout << num_blocks << endl;
|
|
||||||
//cout << lda << endl;
|
|
||||||
//cout << ldb << endl;
|
|
||||||
//cout << ldc << endl;
|
|
||||||
|
|
||||||
//cout << m << endl;
|
|
||||||
//cout << n << endl;
|
|
||||||
//cout << k << endl;
|
|
||||||
kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||||
//kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
|
||||||
//kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
|
||||||
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
|
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
|
||||||
|
@ -770,6 +758,7 @@ template void func<float, _MUL>(float *A, float *B, float value, long n);
|
||||||
|
|
||||||
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||||
template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||||
|
template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
|
||||||
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
|
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
|
||||||
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
|
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
|
||||||
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||||
|
@ -796,19 +785,27 @@ template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||||
|
|
||||||
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
|
||||||
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
|
||||||
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
|
||||||
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
|
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
|
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
|
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
|
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
|
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
|
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||||
|
|
||||||
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||||
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
|
||||||
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||||
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
|
||||||
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||||
|
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||||
|
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||||
|
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||||
|
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||||
|
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||||
|
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||||
|
|
||||||
#define MAKE_optimizer32bit(name, gtype) \
|
#define MAKE_optimizer32bit(name, gtype) \
|
||||||
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
||||||
|
|
|
@ -28,9 +28,12 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
|
||||||
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||||
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||||
|
|
||||||
void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||||
{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||||
|
|
||||||
|
void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
|
||||||
|
{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||||
|
|
||||||
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
|
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
|
||||||
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
|
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
|
||||||
|
|
||||||
|
@ -106,19 +109,29 @@ void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){
|
||||||
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||||
|
|
||||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
|
||||||
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
|
||||||
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
|
|
||||||
|
void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
|
void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
|
void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
|
|
||||||
|
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
|
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||||
|
|
||||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
|
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
|
||||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
|
|
||||||
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
|
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
|
||||||
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
|
|
||||||
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
|
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
|
||||||
|
|
||||||
|
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||||
|
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||||
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
|
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||||
|
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||||
|
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
|
||||||
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
||||||
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
|
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
|
||||||
|
@ -177,21 +190,31 @@ extern "C"
|
||||||
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
||||||
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
||||||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
|
||||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
|
||||||
|
|
||||||
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
|
||||||
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
|
||||||
|
|
||||||
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
|
||||||
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
|
||||||
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
||||||
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||||
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
|
||||||
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
|
||||||
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||||
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
#define MAKE_CFUNC32(name, gtype, gbits) \
|
#define MAKE_CFUNC32(name, gtype, gbits) \
|
||||||
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
|
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
|
||||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||||
|
@ -348,9 +371,6 @@ extern "C"
|
||||||
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||||
{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||||
|
|
||||||
void cgemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
|
||||||
{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
|
||||||
|
|
||||||
void *cget_managed_ptr(size_t bytes)
|
void *cget_managed_ptr(size_t bytes)
|
||||||
{
|
{
|
||||||
void *ptr;
|
void *ptr;
|
||||||
|
@ -374,6 +394,12 @@ extern "C"
|
||||||
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
|
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
|
||||||
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
|
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
|
||||||
|
|
||||||
|
void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||||
|
{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||||
|
|
||||||
|
void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
|
||||||
|
{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
|
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
|
||||||
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
|
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
|
@ -154,34 +154,36 @@ def test_dynamic_quantization():
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||||
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
|
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
|
||||||
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
||||||
def test_dynamic_blockwise_quantization(nested, blocksize):
|
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
|
||||||
#print('')
|
#print('')
|
||||||
diffs = []
|
diffs = []
|
||||||
reldiffs = []
|
reldiffs = []
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
A1 = torch.randn(1024, 1024, device="cuda")
|
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
|
||||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
||||||
A2 = F.dequantize_blockwise(C, S)
|
A2 = F.dequantize_blockwise(C, S)
|
||||||
diff = torch.abs(A1 - A2)
|
diff = torch.abs(A1 - A2).float()
|
||||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
reldiff = diff / torch.abs(A1.float() + 1e-8)
|
||||||
diffs.append(diff.mean().item())
|
diffs.append(diff.mean().item())
|
||||||
reldiffs.append(reldiff.mean().item())
|
reldiffs.append(reldiff.mean().item())
|
||||||
abserr = sum(diffs)/len(diffs)
|
abserr = sum(diffs)/len(diffs)
|
||||||
relerr = sum(reldiffs)/len(reldiffs)
|
relerr = sum(reldiffs)/len(reldiffs)
|
||||||
|
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
|
||||||
|
#print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
|
||||||
assert abserr < 0.011
|
assert abserr < 0.011
|
||||||
assert relerr < 0.018
|
assert relerr < 0.018
|
||||||
#print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
|
assert A2.dtype == dtype
|
||||||
#print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
|
|
||||||
|
|
||||||
diffs = []
|
diffs = []
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
A1 = torch.rand(1024, 1024, device="cuda")
|
A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
|
||||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
||||||
A2 = F.dequantize_blockwise(C, S)
|
A2 = F.dequantize_blockwise(C, S)
|
||||||
diff = torch.abs(A1 - A2)
|
diff = torch.abs(A1 - A2).float()
|
||||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
reldiff = diff / torch.abs(A1.float() + 1e-8)
|
||||||
diffs.append(diff.mean().item())
|
diffs.append(diff.mean().item())
|
||||||
reldiffs.append(reldiff.mean().item())
|
reldiffs.append(reldiff.mean().item())
|
||||||
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
||||||
|
@ -189,6 +191,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
|
||||||
relerr = sum(reldiffs)/len(reldiffs)
|
relerr = sum(reldiffs)/len(reldiffs)
|
||||||
assert abserr < 0.0035
|
assert abserr < 0.0035
|
||||||
assert relerr < 0.015
|
assert relerr < 0.015
|
||||||
|
assert A2.dtype == dtype
|
||||||
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
|
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
|
||||||
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
|
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||||
|
|
||||||
|
@ -1773,8 +1776,8 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
|
||||||
print("partial matmul", time.time() - t0)
|
print("partial matmul", time.time() - t0)
|
||||||
|
|
||||||
|
|
||||||
batch_size = 32
|
batch_size = 1
|
||||||
seqdim = 512+256
|
seqdim = 1
|
||||||
values = []
|
values = []
|
||||||
#values.append((batch_size, seqdim, 768, 4 * 768))
|
#values.append((batch_size, seqdim, 768, 4 * 768))
|
||||||
#values.append((batch_size, seqdim, 1024, 4*1024))
|
#values.append((batch_size, seqdim, 1024, 4*1024))
|
||||||
|
@ -1800,7 +1803,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
B_fp4, state = F.quantize_fp4(B)
|
B_fp4, state = F.quantize_fp4(B)
|
||||||
B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
|
B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
|
||||||
|
|
||||||
B_nf4, state_nf4= F.quantize_nf4(B)
|
B_nf4, state_nf4 = F.quantize_nf4(B)
|
||||||
|
|
||||||
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
|
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
|
||||||
linear8bit.eval()
|
linear8bit.eval()
|
||||||
|
@ -1813,6 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
|
|
||||||
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
|
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
|
||||||
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
|
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
|
||||||
|
F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
|
@ -1844,7 +1848,8 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
|
#bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
|
||||||
|
F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||||
|
|
||||||
|
@ -2221,7 +2226,8 @@ def test_bench_dequantization():
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_fp4_quant():
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||||
|
def test_fp4_quant(dtype):
|
||||||
vals = list(product([0, 1], repeat=4))
|
vals = list(product([0, 1], repeat=4))
|
||||||
|
|
||||||
code = {}
|
code = {}
|
||||||
|
@ -2243,7 +2249,7 @@ def test_fp4_quant():
|
||||||
result = sign*exp*frac
|
result = sign*exp*frac
|
||||||
code[idx] = result
|
code[idx] = result
|
||||||
|
|
||||||
A1 = torch.randn(1024, 1024, device='cuda').half()
|
A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
|
||||||
qa, SA = F.quantize_fp4(A1, blocksize=64)
|
qa, SA = F.quantize_fp4(A1, blocksize=64)
|
||||||
A2 = F.dequantize_fp4(qa, SA)
|
A2 = F.dequantize_fp4(qa, SA)
|
||||||
|
|
||||||
|
@ -2252,7 +2258,7 @@ def test_fp4_quant():
|
||||||
idx = err > 1.0
|
idx = err > 1.0
|
||||||
err = err.mean()
|
err = err.mean()
|
||||||
|
|
||||||
|
assert A2.dtype == dtype
|
||||||
assert err.item() < 0.1
|
assert err.item() < 0.1
|
||||||
assert relerr.item() < 0.28
|
assert relerr.item() < 0.28
|
||||||
|
|
||||||
|
@ -2409,20 +2415,16 @@ def test_cutlass3_gemm(dtype):
|
||||||
print(dim, (max_err.item(), max_relerr.item()))
|
print(dim, (max_err.item(), max_relerr.item()))
|
||||||
|
|
||||||
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
|
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
|
||||||
|
#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
|
||||||
def test_gemm_4bit(dtype):
|
def test_gemm_4bit(dtype):
|
||||||
print('')
|
print('')
|
||||||
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
|
||||||
#for dim in [4096, 5120, 6656, 8192]:
|
|
||||||
#for dim in [32]:
|
|
||||||
for dim in [2*4096]:
|
|
||||||
#for dim in [5120]:
|
|
||||||
#for dim in [6656]:
|
|
||||||
#for dim in [4]:
|
|
||||||
errs = []
|
errs = []
|
||||||
relerrs = []
|
relerrs = []
|
||||||
max_err = 0
|
max_err = 0
|
||||||
max_relerr = 0
|
max_relerr = 0
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
|
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
|
||||||
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
|
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
|
||||||
|
@ -2443,14 +2445,13 @@ def test_gemm_4bit(dtype):
|
||||||
qB, state = F.quantize_nf4(B)
|
qB, state = F.quantize_nf4(B)
|
||||||
F.dequantize_nf4(qB, state)
|
F.dequantize_nf4(qB, state)
|
||||||
|
|
||||||
#C3 = torch.matmul(A, B.t())
|
#C2 = bnb.matmul_4bit(A, qB.t(), state)
|
||||||
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
|
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
|
||||||
C1 = bnb.matmul_4bit(A, qB.t(), state)
|
C1 = torch.matmul(A, B.t())
|
||||||
|
|
||||||
#print(state)
|
#print(state)
|
||||||
#print(qB)
|
#print(qB)
|
||||||
|
|
||||||
|
|
||||||
#print('')
|
#print('')
|
||||||
#print(A)
|
#print(A)
|
||||||
#print(B)
|
#print(B)
|
||||||
|
@ -2464,8 +2465,8 @@ def test_gemm_4bit(dtype):
|
||||||
# tensor cores are non-deterministic
|
# tensor cores are non-deterministic
|
||||||
# so we need to analyze errors around the mean
|
# so we need to analyze errors around the mean
|
||||||
# to test our implementation
|
# to test our implementation
|
||||||
err = torch.abs(C1-C2)
|
err = torch.abs(C1-C2).float()
|
||||||
mag = torch.abs(C1)+1e-8
|
mag = torch.abs(C1).float()+1e-5
|
||||||
relerr = err/mag
|
relerr = err/mag
|
||||||
max_err = max(err.max(), max_err)
|
max_err = max(err.max(), max_err)
|
||||||
max_relerr = max(relerr.max(), max_relerr)
|
max_relerr = max(relerr.max(), max_relerr)
|
||||||
|
@ -2476,27 +2477,17 @@ def test_gemm_4bit(dtype):
|
||||||
errs.append(err)
|
errs.append(err)
|
||||||
relerrs.append(relerr)
|
relerrs.append(relerr)
|
||||||
|
|
||||||
if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
|
|
||||||
print('')
|
|
||||||
print(i, err, relerr)
|
|
||||||
#print(A.flatten()[-6:])
|
|
||||||
#print(B.flatten()[-6:])
|
|
||||||
#out = A.flatten()[-6:]*B.flatten()[-6:]
|
|
||||||
#print(out)
|
|
||||||
#print(out[:-1].sum())
|
|
||||||
print('='*80)
|
|
||||||
#print(C1.flatten()[-6:])
|
|
||||||
#print(C2.flatten()[-6:])
|
|
||||||
#assert False, 'ERROR'
|
|
||||||
|
|
||||||
c = int(C1.numel()*0.0014*(dim/256))+1
|
c = int(C1.numel()*0.0014*(dim/256))+1
|
||||||
|
|
||||||
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
|
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
|
||||||
print(c/math.sqrt(dim))
|
#print('')
|
||||||
print('')
|
#print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
||||||
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
#print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
|
||||||
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
|
#print(dim, (max_err.item(), max_relerr.item()))
|
||||||
print(dim, (max_err.item(), max_relerr.item()))
|
#print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
|
||||||
|
#print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
|
||||||
|
assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
|
||||||
|
assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
|
||||||
|
|
||||||
@pytest.mark.skip("Row scale has some bugs for ampere")
|
@pytest.mark.skip("Row scale has some bugs for ampere")
|
||||||
def test_managed():
|
def test_managed():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user