Merge remote-tracking branch 'origin/inference'
commit 5f492d437e

 Makefile | 4

@@ -47,8 +47,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80

CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86

CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
@@ -512,7 +512,7 @@ class MatMul4Bit(torch.autograd.Function):
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
|
||||
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
|
||||
|
||||
# 3. Save state
|
||||
ctx.state = state
@@ -543,7 +543,7 @@ class MatMul4Bit(torch.autograd.Function):
# not supported by PyTorch. TODO: create work-around
|
||||
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
|
||||
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())
|
||||
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
|
||||
|
||||
return grad_A, grad_B, None, grad_bias, None
@@ -564,4 +564,7 @@ def matmul(
def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
|
||||
assert quant_state is not None
|
||||
return MatMul4Bit.apply(A, B, out, bias, quant_state)
|
||||
if A.numel() == A.shape[-1] and A.requires_grad == False:
|
||||
return F.gemv_4bit(A, B.t(), out, state=quant_state)
|
||||
else:
|
||||
return MatMul4Bit.apply(A, B, out, bias, quant_state)
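
A quick usage sketch of the dispatch above (illustrative shapes and module paths; the weight is assumed to come from F.quantize_4bit): single-token, no-grad activations take the fused F.gemv_4bit path, while everything else still goes through MatMul4Bit.

import torch
import bitsandbytes.functional as F
from bitsandbytes.autograd._functions import matmul_4bit

# Quantize an fp16 weight to NF4 (shapes are illustrative).
W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
W_q, quant_state = F.quantize_4bit(W, quant_type='nf4')

# Inference path: a [1, 1, K] activation without grad hits F.gemv_4bit.
x = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')
with torch.no_grad():
    y = matmul_4bit(x, W_q.t(), quant_state=quant_state)

# Training / multi-token path: falls back to MatMul4Bit (dequantize + linear).
xb = torch.randn(8, 16, 4096, dtype=torch.float16, device='cuda', requires_grad=True)
yb = matmul_4bit(xb, W_q.t(), quant_state=quant_state)
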
@@ -240,17 +240,19 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
|
||||
v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
|
||||
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
|
||||
v = v1 + v2 + v3
|
||||
else:
|
||||
v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
|
||||
v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
|
||||
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
|
||||
v = v1 + v2 + v3
|
||||
|
||||
v = v1 + v2 + v3
|
||||
|
||||
values = torch.Tensor(v)
|
||||
values = values.sort().values
|
||||
values /= values.max()
|
||||
|
||||
assert values.numel() == 256
|
||||
|
||||
return values
|
||||
|
||||
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
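
For readability, this is what the patched branch structure of create_normal_map condenses to (reassembled from the lines above; both branches now only build v1/v2/v3 and the concatenation happens once):

from scipy.stats import norm
import torch

def create_normal_map(offset=0.9677083, use_extra_value=True):
    if use_extra_value:
        v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
        v2 = [0] * (256 - 15)  # 15 non-zero values in this data type
        v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
    else:
        v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
        v2 = [0] * (256 - 14)  # 14 non-zero values in this data type
        v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()

    v = v1 + v2 + v3
    values = torch.Tensor(v).sort().values
    values /= values.max()
    assert values.numel() == 256
    return values
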
@@ -617,6 +619,8 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.bfloat16:
|
||||
lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
@@ -629,11 +633,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
offset = absmax.mean()
|
||||
absmax -= offset
|
||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
|
||||
state = [qabsmax, code, blocksize, nested, offset, state2]
|
||||
state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
|
||||
else:
|
||||
state = [absmax, code, blocksize, nested, None, None]
|
||||
|
||||
|
||||
state = [absmax, code, blocksize, nested, A.dtype, None, None]
|
||||
|
||||
return out, state
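
The blockwise quantization state now carries the input dtype (fifth element), so the inverse call can allocate its output in the original precision. A minimal round-trip sketch under that assumption:

import torch
import bitsandbytes.functional as F

A = torch.randn(1024, device='cuda', dtype=torch.bfloat16)
q, state = F.quantize_blockwise(A, blocksize=4096)

# state == [absmax, code, blocksize, nested, dtype, offset, state2]
A_dq = F.dequantize_blockwise(q, state)
assert A_dq.dtype == torch.bfloat16 and A_dq.shape == A.shape
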
@@ -678,18 +680,16 @@ def dequantize_blockwise(
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
code = name2qmap["dynamic"]
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.float32)
|
||||
|
||||
if quant_state is None:
|
||||
quant_state = (absmax, code, blocksize)
|
||||
assert absmax is not None and out is not None
|
||||
else:
|
||||
absmax, code, blocksize, nested, offset, state2 = quant_state
|
||||
if nested:
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
|
||||
|
||||
absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
|
||||
if nested:
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
|
||||
if out is None:
|
||||
out = torch.empty(A.shape, dtype=dtype, device=A.device)
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
device = pre_call(A.device)
@@ -701,6 +701,8 @@ def dequantize_blockwise(
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
elif out.dtype == torch.float16:
|
||||
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
elif out.dtype == torch.bfloat16:
|
||||
lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
@@ -710,6 +712,47 @@ def dequantize_blockwise(
return out
|
||||
|
||||
def get_4bit_type(typename, device=None, blocksize=64):
|
||||
if device is None: device = 'cuda'
|
||||
data = None
|
||||
if typename == 'nf4':
|
||||
data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
|
||||
-0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
|
||||
0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
|
||||
0.7229568362236023, 1.0]
|
||||
elif typename == 'fp4':
|
||||
# 0b000 = 0
|
||||
# 0b001 = 0.0625
|
||||
# 0b010 = 8
|
||||
# 0b011 = 12
|
||||
# 0b100 = 4
|
||||
# 0b101 = 6
|
||||
# 0b110 = 2
|
||||
# 0b111 = 3
|
||||
data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
|
||||
elif typename == 'int4':
|
||||
data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
|
||||
elif typename == 'af4':
|
||||
# Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
|
||||
# https://arxiv.org/abs/2306.06965
|
||||
if blocksize == 64:
|
||||
data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478,
|
||||
-0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666,
|
||||
0.42563882, 0.55496234, 0.72424863, 1.][::-1]
|
||||
else:
|
||||
raise NotImplementedError(f'4-bit AbnormalFloats currently only supports blocksize 64.')
|
||||
|
||||
if data is None:
|
||||
raise NotImplementedError(f'Typename {typename} not supported')
|
||||
|
||||
data = Tensor(data)
|
||||
data /= data.abs().max()
|
||||
assert data.numel() == 16
|
||||
|
||||
return data.to(device)
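
A small sanity check on the helper above, using the nf4 branch: the returned codebook is normalized to a maximum magnitude of 1, stores the zero point at index 7, and always has 16 entries.

import torch
import bitsandbytes.functional as F

nf4 = F.get_4bit_type('nf4', device='cpu')
assert nf4.numel() == 16
assert nf4[7] == 0.0
assert torch.isclose(nf4.abs().max(), torch.tensor(1.0))
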
|
||||
|
||||
|
||||
|
||||
def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
|
||||
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
@@ -774,20 +817,25 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
elif A.dtype == torch.bfloat16:
|
||||
if quant_type == 'fp4':
|
||||
lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
||||
datatype = get_4bit_type(quant_type, device=A.device)
|
||||
|
||||
if compress_statistics:
|
||||
offset = absmax.mean()
|
||||
absmax -= offset
|
||||
#code = create_custom_map().to(absmax.device)
|
||||
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
|
||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
|
||||
del absmax
|
||||
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
|
||||
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
|
||||
else:
|
||||
state = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
|
||||
state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]
|
||||
|
||||
return out, state
@@ -834,7 +882,7 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
shape = out.shape
|
||||
dtype = out.dtype
|
||||
else:
|
||||
absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state
|
||||
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
|
||||
|
||||
|
||||
if compressed_stats is not None:
@@ -860,6 +908,11 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
elif out.dtype == torch.bfloat16:
|
||||
if quant_type == 'fp4':
|
||||
lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
@@ -1398,7 +1451,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
return sout
|
||||
|
||||
def cutlass3_gemm(
|
||||
def gemv_4bit(
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
out: Tensor = None,
@@ -1406,95 +1459,35 @@ def cutlass3_gemm(
transposed_B=False,
|
||||
state=None
|
||||
):
|
||||
prev_device = pre_call(A.device)
|
||||
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
|
||||
if state is None:
|
||||
Bshape = B.shape
|
||||
bout = Bshape[1]
|
||||
else:
|
||||
Bshape = state[1]
|
||||
bout = Bshape[0]
|
||||
raise ValueError(f'state cannot be None. gemv_4bit() requires the state from quantize_4bit()')
|
||||
|
||||
if A.numel() != A.shape[-1]:
|
||||
raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
|
||||
|
||||
Bshape = state[1]
|
||||
bout = Bshape[0]
|
||||
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
|
||||
if compressed_stats is not None:
|
||||
offset, state2 = compressed_stats
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
|
||||
|
||||
sA = A.shape
|
||||
sB = B.shape
|
||||
if transposed_A and len(sA) == 2:
|
||||
sA = (sA[1], sA[0])
|
||||
elif transposed_A and len(sA) == 3:
|
||||
sA = (sA[0], sA[2], sA[0])
|
||||
if transposed_B and len(sB) == 2:
|
||||
sB = (sB[1], sB[0])
|
||||
elif transposed_B and len(sB) == 3:
|
||||
sB = (sB[0], sB[2], sB[0])
|
||||
# this is a mess: cuBLAS expect column major, but PyTorch is row major.
|
||||
# So to perform the matrix multiplication, we have to treat A, B, and C matrices
|
||||
# (transpose of row major is column major)
|
||||
# This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
|
||||
|
||||
# matrices in the input arguments for cuBLAS
|
||||
# column major: A @ B = C: [m, k] @ [k, n] = [m, n]
|
||||
# row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
|
||||
# column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
|
||||
if len(sB) == 2:
|
||||
if B.stride()[0] == B.shape[1]:
|
||||
transposed_B = False
|
||||
elif B.stride()[1] == B.shape[0]:
|
||||
transposed_B = True
|
||||
if len(A.shape) == 2:
|
||||
if A.stride()[0] == A.shape[1]:
|
||||
transposed_A = False
|
||||
elif A.stride()[1] == A.shape[0]:
|
||||
transposed_A = True
|
||||
if len(A.shape) == 3:
|
||||
out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
if A.stride()[1] == A.shape[2]:
|
||||
transposed_A = False
|
||||
elif A.stride()[2] == A.shape[1]:
|
||||
transposed_A = True
|
||||
out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
|
||||
|
||||
if len(sA) == 2:
|
||||
n = sA[0]
|
||||
ldb = A.stride()[1 if transposed_A else 0]
|
||||
elif len(sA) == 3 and len(sB) == 2:
|
||||
n = sA[0] * sA[1]
|
||||
ldb = sA[2]
|
||||
|
||||
m = sB[1]
|
||||
k = sB[0]
|
||||
lda = B.stride()[0]
|
||||
ldc = sB[1]
|
||||
elif len(sB) == 3:
|
||||
# special case
|
||||
assert len(sA) == 3
|
||||
if not (sA[0] == sB[0] and sA[1] == sB[1]):
|
||||
raise ValueError(
|
||||
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
|
||||
)
|
||||
|
||||
transposed_A = True
|
||||
transposed_B = False
|
||||
|
||||
m = sB[2]
|
||||
n = sA[2]
|
||||
k = sB[0] * sB[1]
|
||||
|
||||
lda = n
|
||||
ldb = sA[2]
|
||||
ldc = m
|
||||
|
||||
ptr = CUBLAS_Context.get_instance().get_context(A.device)
|
||||
|
||||
# B^T @ A^T = C^T
|
||||
# [km, nk -> mn]
|
||||
#lda = ldb = ldc = 1
|
||||
#lda = 1
|
||||
if state is not None:
|
||||
m = Bshape[0]
|
||||
k = Bshape[1]
|
||||
lda = Bshape[0]
|
||||
ldc = Bshape[0]
|
||||
ldb = (ldb+1)//2
|
||||
#print(m, n, k, lda, ldb, ldc)
|
||||
is_on_gpu([B, A, out])
|
||||
n = 1
|
||||
m = Bshape[0]
|
||||
k = Bshape[1]
|
||||
lda = Bshape[0]
|
||||
ldc = Bshape[0]
|
||||
ldb = (A.shape[-1]+1)//2
|
||||
is_on_gpu([B, A, out, absmax, state[-1]])
|
||||
m = ct.c_int32(m)
|
||||
n = ct.c_int32(n)
|
||||
k = ct.c_int32(k)
@@ -1503,16 +1496,20 @@ def cutlass3_gemm(
ldc = ct.c_int32(ldc)
|
||||
|
||||
if B.dtype == torch.uint8:
|
||||
lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
elif A.dtype == torch.float32:
|
||||
lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
||||
if A.dtype == torch.float16:
|
||||
lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
elif A.dtype == torch.bfloat16:
|
||||
lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
elif A.dtype == torch.float32:
|
||||
lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
else:
|
||||
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
|
||||
else:
|
||||
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
|
||||
|
||||
return out
|
||||
post_call(prev_device)
|
||||
|
||||
return out
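
Condensed behaviour of the rewritten gemv_4bit above: A must be a single row (A.numel() == A.shape[-1]), and the 7-element quant_state from quantize_4bit supplies the packed shape (state[1]), the absmax (state[0], optionally nested) and the 4-bit codebook (state[-1]) before one of the cgemm_4bit_inference_naive_* kernels is launched for the matching dtype. A sketch of a direct call, with illustrative sizes:

import torch
import bitsandbytes.functional as F

K, N = 4096, 11008                                   # input / output features (illustrative)
W = torch.randn(N, K, dtype=torch.float16, device='cuda')
W_q, state = F.quantize_4bit(W, quant_type='nf4')    # state[1] == (N, K)

x = torch.randn(1, 1, K, dtype=torch.float16, device='cuda')
y = F.gemv_4bit(x, W_q.t(), state=state)
assert y.shape == (1, 1, N)
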
@@ -190,9 +190,9 @@ class Params4bit(torch.nn.Parameter):
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
|
||||
|
||||
# for 8-bit
|
||||
s[-2][0] = s[-2][0].to(device) # offset
s[-2][1][0] = s[-2][1][0].to(device) # nested quantization state statistics
s[-2][1][1] = s[-2][1][1].to(device) # nested quantization codebook
s[-3][0] = s[-3][0].to(device) # offset
s[-3][1][0] = s[-3][1][0].to(device) # nested quantization state statistics
s[-3][1][1] = s[-3][1][1].to(device) # nested quantization codebook
|
||||
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
|
||||
requires_grad=self.requires_grad, quant_state=self.quant_state,
|
||||
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
 csrc/kernels.cu | 244

@@ -3088,7 +3088,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
}
|
||||
}
|
||||
|
||||
#define WARPS 5
|
||||
#define WARPS 3
|
||||
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
|
||||
{
@@ -3297,33 +3297,58 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
#endif
|
||||
}
|
||||
|
||||
|
||||
template <typename T> __device__ void printnonzero(T *A, int num_values, const char * strval)
|
||||
{
|
||||
for(int i = 0; i < num_values; i++)
|
||||
if((float)A[i] != 0.0)
|
||||
printf("%s %i %f\n", strval, i, (float)A[i]);
|
||||
}
|
||||
|
||||
template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
|
||||
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
|
||||
|
||||
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
|
||||
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{
|
||||
|
||||
#if __CUDA_ARCH__ >= 750
|
||||
using namespace nvcuda;
|
||||
int col_offset = blockIdx.x *32;
|
||||
const int warp_id = threadIdx.x / 32;
|
||||
const int warp_idx = threadIdx.x % 32;
|
||||
const int half_warp_id = threadIdx.x / 16;
|
||||
const int half_warp_lane = threadIdx.x % 16;
|
||||
const int batch_size_warps = (WARPS-1)*2;
|
||||
|
||||
T quant_map[16];
|
||||
|
||||
#pragma unroll 16
|
||||
for(int i = 0; i < 16; i++)
|
||||
quant_map[i] = nf4_data[i];
|
||||
//__shared__ T quant_map[16*160];
|
||||
|
||||
T local_A[2];
|
||||
T local_B[64];
|
||||
unsigned char local_B_4bit[32];
|
||||
|
||||
|
||||
const int a_tile_offset = 16;
|
||||
const int b_tile_offset = (16*32 + 16);
|
||||
|
||||
__shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
|
||||
__shared__ T smem_A[8*16 + (16*(batch_size_warps-1))];
|
||||
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
|
||||
//__shared__ T smem_C[8*32];
|
||||
__shared__ T smem_C[8*32];
|
||||
|
||||
wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
|
||||
wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
|
||||
wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
|
||||
wmma::fill_fragment(c_frag, 0.0f);
|
||||
|
||||
for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
|
||||
smem_C[i] = 0.0f;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int ticktock = 0;
|
||||
int idx = 0 + threadIdx.x;
|
||||
int loaded_values = 0;
@@ -3349,8 +3374,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
#pragma unroll 64
|
||||
for(int col = 0; col < 64; col+=2)
|
||||
{
|
||||
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
|
||||
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
|
||||
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
|
||||
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
|
||||
//local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
|
||||
//local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
|
||||
//local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
|
||||
//local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
|
||||
|
||||
//local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
|
||||
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
|
||||
local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0);
|
||||
local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0);
|
||||
}
|
||||
}
@@ -3374,13 +3408,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
|
||||
}
|
||||
ticktock = ticktock == 0 ? 1 : 0;
|
||||
//if(threadIdx.x == 0)
|
||||
//printf("aa %i %i\n", idx, loaded_values);
|
||||
|
||||
//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
|
||||
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
|
||||
{
|
||||
idx = base_idx + threadIdx.x;
|
||||
//if(threadIdx.x == 0)
|
||||
//printf("%i %i\n", idx, loaded_values);
|
||||
|
||||
__syncthreads();
|
||||
//__syncthreads();
|
||||
if(idx < K && warp_id < (WARPS-1))
|
||||
{
|
||||
if(loaded_values == 0)
@@ -3408,9 +3446,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
#pragma unroll 64
|
||||
for(int col = 0; col < 64; col+=2)
|
||||
{
|
||||
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
|
||||
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
|
||||
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
|
||||
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
|
||||
//local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
|
||||
//local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
|
||||
|
||||
//local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
|
||||
//local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
|
||||
local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
|
||||
local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
|
||||
}
|
||||
//printnonzero<T>(local_B, 128, "");
|
||||
}
|
||||
|
||||
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
@@ -3444,6 +3490,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
}
|
||||
|
||||
__syncthreads();
|
||||
//if(threadIdx.x == 0)
|
||||
//{
|
||||
// printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
|
||||
// printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
|
||||
//}
|
||||
if(warp_id != (WARPS-1)){ return; }
|
||||
// only warp_id == (WARPS-1) from here
|
||||
int warp_lane = threadIdx.x % 32;
@@ -3451,6 +3502,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
ticktock = ticktock == 0 ? 1 : 0;
|
||||
for(int k = 0; k < batch_size_warps; k++)
|
||||
{
|
||||
//if(warp_lane == 0)
|
||||
//printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
|
||||
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
|
||||
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
|
||||
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
@@ -3458,13 +3511,116 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
// 129 mu
|
||||
if(warp_id == (WARPS-1))
|
||||
wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
|
||||
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
|
||||
|
||||
//printnonzero<T>(smem_C, 32, "");
|
||||
|
||||
if(col_offset + warp_lane < M)
|
||||
out[col_offset + warp_lane] = smem_A[warp_lane];
|
||||
#endif
|
||||
out[col_offset + warp_lane] = smem_C[warp_lane];
|
||||
}
|
||||
|
||||
#define num_values_4bit 32
|
||||
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{
|
||||
|
||||
// per threadblock:
|
||||
// load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
|
||||
// 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
|
||||
// 4 warps -> 4 loads per iter
|
||||
// 1x128 * 128x4 -> 1x4 outputs
|
||||
typedef cub::WarpReduce<float> WarpReduce;
|
||||
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
|
||||
|
||||
const int warp_idx = threadIdx.x / 32;
|
||||
const int warp_lane = threadIdx.x % 32;
|
||||
const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
|
||||
const int num_values_8bit = num_values_4bit/2;
|
||||
float local_C = 0.0f;
|
||||
|
||||
unsigned char local_B_4bit[num_values_8bit];
|
||||
T local_B[num_values_4bit];
|
||||
T local_A[num_values_4bit];
|
||||
__shared__ T quant_map[16];
|
||||
T local_absmax = T(0.0f);
|
||||
|
||||
for(int i = threadIdx.x; i < 16; i++)
|
||||
quant_map[i] = datatype[i];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// A: [1, K]
|
||||
// B: [N, K]
|
||||
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
|
||||
{
|
||||
int inner_idx_halved = inner_idx/2;
|
||||
int offset_B = ldb*row_B;
|
||||
int absidx = ((2*offset_B)+inner_idx)/blocksize;
|
||||
local_absmax = __ldg(&(absmax[absidx]));
|
||||
|
||||
if(row_B < M)
|
||||
{
|
||||
if((inner_idx_halved + num_values_8bit) < K)
|
||||
{
|
||||
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll
|
||||
for(int j = 0; j < (num_values_8bit); j++)
|
||||
if((inner_idx_halved) + j < K)
|
||||
local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
|
||||
else
|
||||
local_B_4bit[j] = 0b01110111;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < num_values_4bit; k++)
|
||||
{
|
||||
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
|
||||
local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
|
||||
}
|
||||
|
||||
if(inner_idx+num_values_4bit)
|
||||
{
|
||||
if(BITS==16)
|
||||
{
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 1];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 2];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 3];
|
||||
}
|
||||
else
|
||||
{
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 0];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 1];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 2];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 3];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[4] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 4];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[5] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 5];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[6] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 6];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[7] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 7];
|
||||
}
|
||||
|
||||
}
|
||||
else
|
||||
for(int k = 0; k < num_values_4bit; k++)
|
||||
local_A[k] = A[inner_idx +k];
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < num_values_4bit; k++)
|
||||
local_C += (float)(local_A[k]*local_B[k]);
|
||||
|
||||
}
|
||||
|
||||
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
|
||||
|
||||
if(row_B < M && warp_lane == 0)
|
||||
out[row_B] = T(local_C);
|
||||
|
||||
}
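
As a host-side reference model for kgemm_4bit_inference_naive above (a sketch of the arithmetic only, not of the CUDA implementation): each warp produces one output element as a dot product between the activation row and one dequantized row of B, where every nibble is looked up in the 16-entry codebook and scaled by its block's absmax.

import torch

def gemv_4bit_reference(A, B_packed, absmax, codebook, blocksize, N, K):
    # A: [K] activations; B_packed: [N*K/2] uint8, two 4-bit values per byte.
    out = torch.empty(N, dtype=A.dtype)
    for row in range(N):
        vals = torch.empty(K, dtype=torch.float32)
        for i in range(K // 2):
            byte = int(B_packed[row * (K // 2) + i])
            scale = float(absmax[(row * K + 2 * i) // blocksize])
            vals[2 * i] = float(codebook[byte >> 4]) * scale
            vals[2 * i + 1] = float(codebook[byte & 0x0F]) * scale
        out[row] = (A.float() * vals).sum().to(A.dtype)
    return out
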
|
||||
|
||||
|
||||
//#define ROWS 2
|
||||
//template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
|
||||
//{
@@ -3627,8 +3783,14 @@ template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * _
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
|
||||
template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
@@ -3784,6 +3946,20 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
@@ -3792,13 +3968,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
@@ -3806,13 +3975,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
@@ -3821,12 +3983,38 @@ MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
|
||||
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)
|
||||
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
|
||||
|
||||
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
|
||||
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
|
@@ -106,6 +106,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
|
||||
|
||||
|
||||
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
@@ -124,6 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
|
||||
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
 csrc/ops.cu | 36

@@ -723,10 +723,20 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
//cout << m << endl;
|
||||
//cout << n << endl;
|
||||
//cout << k << endl;
|
||||
kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
//kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
//kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
}
|
||||
|
||||
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{
|
||||
|
||||
int num_blocks = (m+3)/4;
|
||||
|
||||
kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
|
||||
}
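
The launch configuration above, spelled out: 128 threads per block is 4 warps, and with one output row per warp the grid needs (m + 3) / 4 blocks.

def naive_gemv_grid(m, threads=128):
    warps_per_block = threads // 32                               # 4 warps per block
    num_blocks = (m + warps_per_block - 1) // warps_per_block     # == (m + 3) // 4
    return num_blocks, threads
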
|
||||
|
||||
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
|
||||
{
|
||||
int threads = 512;
@@ -747,6 +757,10 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);
|
||||
|
||||
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
|
||||
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
|
||||
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
@@ -773,19 +787,27 @@ template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
|
||||
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
|
||||
|
||||
#define MAKE_optimizer32bit(name, gtype) \
|
||||
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
@@ -200,6 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
|
||||
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
@@ -28,6 +28,15 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
|
||||
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
@@ -103,19 +112,29 @@ void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
||||
void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
|
||||
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
|
||||
|
||||
void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); }
|
||||
|
||||
|
||||
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
||||
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
@@ -174,21 +193,31 @@ extern "C"
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }

void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }

void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }

void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }

void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }

void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }

void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }

void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }

#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \

@@ -368,6 +397,15 @@ extern "C"
  CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
  CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)

  void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
  { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

  void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
  { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

  void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
  { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
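The cgemm_4bit_inference_naive_* wrappers above are the C entry points for the new 4-bit inference path that test_gemv_4bit further down exercises through F.gemv_4bit. As a rough, illustrative sketch (not part of this diff; the hidden size, dtype, and random weights are assumptions), the same path from Python looks like:

import math
import torch
import bitsandbytes.functional as F

dim = 4096                                                      # assumed hidden size
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')     # single-token activation (gemv case)
B = torch.randn(4*dim, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

qB, state = F.quantize_4bit(B, quant_type='nf4')                # packed 4-bit weight + quantization state
out = F.gemv_4bit(A, qB.t(), state=state)                       # fp16 input -> cgemm_4bit_inference_naive_fp16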
@@ -154,34 +154,36 @@ def test_dynamic_quantization():
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
def test_dynamic_blockwise_quantization(nested, blocksize):
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
    #print('')
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
    #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
    assert A2.dtype == dtype

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)

@@ -189,6 +191,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.0035
    assert relerr < 0.015
    assert A2.dtype == dtype
    #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))

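With the bf16 wrappers added above and the dtype parametrization of this test, bfloat16 inputs now go through the same blockwise quantization path as fp16/fp32. A minimal sketch of that roundtrip (illustrative only; the shape is taken from the test, the single-trial error check loosely mirrors the averaged bound it asserts):

import torch
import bitsandbytes.functional as F

A1 = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16)
C, S = F.quantize_blockwise(A1, blocksize=64, nested=False)   # uint8 codes + per-block absmax state
A2 = F.dequantize_blockwise(C, S)                             # dequantizes back to bfloat16

assert A2.dtype == A1.dtype
assert (A1 - A2).abs().float().mean().item() < 0.011          # roughly the bound asserted above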
@@ -1781,16 +1784,16 @@ values = []
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
values.append((batch_size, seqdim, 4096, 4*4096))
values.append((batch_size, seqdim, 5120, 4*5120))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5120, 4*5120))
values.append((batch_size, seqdim, 6656, 4*6656))
values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    iters = 80
    iters = 1000
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()

@@ -1800,7 +1803,8 @@ def test_bench_matmul(batch, seq, model, hidden):
    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)

    B_nf4, state_nf4= F.quantize_nf4(B)
    B_nf4, state_nf4 = F.quantize_nf4(B)
    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
    linear8bit.eval()

@@ -1813,6 +1817,7 @@ def test_bench_matmul(batch, seq, model, hidden):

    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)

    # warmup
    for i in range(iters):

@@ -1827,19 +1832,19 @@ def test_bench_matmul(batch, seq, model, hidden):
    torch.cuda.synchronize()
    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    torch.cuda.synchronize()
    print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    #torch.cuda.synchronize()
    #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    torch.cuda.synchronize()
    print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    #torch.cuda.synchronize()
    #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
@@ -1848,6 +1853,14 @@ def test_bench_matmul(batch, seq, model, hidden):
    torch.cuda.synchronize()
    print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
    torch.cuda.synchronize()
    print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )


    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):

@@ -1901,21 +1914,21 @@ def test_bench_matmul(batch, seq, model, hidden):
    #torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    linear8bit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        linear8bit(A)
    torch.cuda.synchronize()
    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    #linear8bit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linear8bit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    linearMixedBit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        linearMixedBit(A)
    torch.cuda.synchronize()
    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    #linearMixedBit(A)
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    linearMixedBit(A)
    #torch.cuda.synchronize()
    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

    #linear8bit_train(A)
    #torch.cuda.synchronize()
@@ -2221,7 +2234,8 @@ def test_bench_dequantization():


def test_fp4_quant():
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_fp4_quant(dtype):
    vals = list(product([0, 1], repeat=4))

    code = {}

@@ -2243,7 +2257,7 @@ def test_fp4_quant():
        result = sign*exp*frac
        code[idx] = result

    A1 = torch.randn(1024, 1024, device='cuda').half()
    A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

@@ -2252,7 +2266,7 @@ def test_fp4_quant():
    idx = err > 1.0
    err = err.mean()


    assert A2.dtype == dtype
    assert err.item() < 0.1
    assert relerr.item() < 0.28

@@ -2297,7 +2311,8 @@ def test_4bit_compressed_stats(quant_type):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()

@@ -2311,7 +2326,7 @@ def test_bench_4bit_dequant(quant_type):
    #print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024*12, device='cuda').half()

    iters = 5
    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
@@ -2344,139 +2359,88 @@ def test_normal_map_tree():
    print(pivots)


#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_cutlass3_gemm(dtype):
    debug = True
    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
    #for dim in [4096, 5120, 6656, 8192]:
    for dim in [4096]:
    #for dim in [128+1]:
@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant):
    print('')
    for dim in [128, 256, 512, 1024, 2048, 4096]:
    #for dim in [4*1024]:
    #for dim in [1*16]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0

        for i in range(100):
            A = torch.randn(1, dim, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

            #print('')
            #print(A)
            #print(B.t())
            #A[:, :-1] = 0
            #B[:, :-1] = 0


            C1 = torch.matmul(A, B.t())
            C2 = F.cutlass3_gemm(A, B.t())

            # tensor cores are non-deterministic
            # so we need to analyze errors around the mean
            # to test our implementation
            err = torch.abs(C1-C2)
            mag = torch.abs(C1)+1e-8
            relerr = err/mag
            max_err = max(err.max(), max_err)
            max_relerr = max(relerr.max(), max_relerr)
            err = err.mean().item()
            relerr = relerr.mean().item()

            errs.append(err)
            relerrs.append(relerr)

            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
            #    print('')
            #    print(i, err, relerr)
            #    print(A.flatten()[-6:])
            #    print(B.flatten()[-6:])
            #    out = A.flatten()[-6:]*B.flatten()[-6:]
            #    print(out)
            #    print(out[:-1].sum())
            #    print('='*80)
            #    print(C1.flatten()[-6:])
            #    print(C2.flatten()[-6:])
            #    #assert False, 'ERROR'

        c = int(C1.numel()*0.0014*(dim/256))+1

        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
        #print(c/math.sqrt(dim))
        print('')
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        print(dim, (max_err.item(), max_relerr.item()))

#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_gemm_4bit(dtype):
    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
    #for dim in [4096, 5120, 6656, 8192]:
    #for dim in [32]:
    for dim in [4096]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0
        for i in range(1):
            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
            A = torch.randn(1, dim, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)

            #print('')
            #print(A)
            #print(B.t())
            #A[:, :-1] = 0
            #B[:, :-1] = 0
            #A.flatten()[:-1] = 0
            #B.flatten()[:-1] = 0

            qB, state = F.quantize_nf4(B)
            F.dequantize_nf4(qB, state)
            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
            #F.dequantize_4bit(qB, state)

            C3 = torch.matmul(A, B.t())
            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)
            C2 = F.cutlass3_gemm(A, qB.t(), state=state)

            print(C1.shape, C2.shape)
            #print(state)
            #print(qB)

            #print('')
            #print(A)
            #print(B)
            #print('='*89)
            #print(C3)

            #print(C1.shape, C2.shape)

            # tensor cores are non-deterministic
            # so we need to analyze errors around the mean
            # to test our implementation
            err = torch.abs(C1-C2)
            mag = torch.abs(C1)+1e-8
            err = torch.abs(C1-C2).float()
            mag = torch.abs(C1).float()+1e-5
            relerr = err/mag
            max_err = max(err.max(), max_err)
            max_relerr = max(relerr.max(), max_relerr)
            err = err.mean().item()
            relerr = relerr.mean().item()
            #print(err)

            errs.append(err)
            relerrs.append(relerr)

            if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
                print('')
                print(i, err, relerr)
                print(A.flatten()[-6:])
                print(B.flatten()[-6:])
                out = A.flatten()[-6:]*B.flatten()[-6:]
                print(out)
                print(out[:-1].sum())
                print('='*80)
                print(C1.flatten()[-6:])
                print(C2.flatten()[-6:])
                #assert False, 'ERROR'

        c = int(C1.numel()*0.0014*(dim/256))+1

        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        #print(c/math.sqrt(dim))
        print('')
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        print(dim, (max_err.item(), max_relerr.item()))
        #print('')
        #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        #print(dim, (max_err.item(), max_relerr.item()))
        print(C1.flatten()[-20:])
        print(C2.flatten()[-20:])
        print(C3.flatten()[-20:])
        print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
        print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
        if dtype == torch.float16:
            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
        else:
            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
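The pass/fail criterion above normalizes the mean absolute error by sqrt(dim) so that one threshold can serve every tested hidden size. A compact, illustrative restatement of that metric (assuming C1 is the reference matmul result and C2 the 4-bit gemv result, as in the loop above):

err = (C1 - C2).abs().float()
relerr = err / (C1.abs().float() + 1e-5)
scaled_err = err.mean().item() / math.sqrt(dim)      # compared against 5e-5 (fp16) or 3e-4 otherwise
scaled_relerr = relerr.mean().item() / math.sqrt(dim) # compared against 0.0005 (fp16) or 0.003 otherwise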

@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():

@@ -535,6 +535,7 @@ def test_kbit_backprop(module):
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()
    kbit = kbit.half().to('cuda')

    errs1 = []
    errs2 = []