Merge branch 'main' into cleanup
This commit is contained in:
commit
b104ce3b62
|
@ -52,8 +52,13 @@ class CUDASetup:
|
|||
self.add_log_entry('python setup.py install')
|
||||
|
||||
def initialize(self):
|
||||
self.cuda_setup_log = []
|
||||
self.has_printed = False
|
||||
self.lib = None
|
||||
self.run_cuda_setup()
|
||||
|
||||
def run_cuda_setup(self):
|
||||
self.initialized = True
|
||||
self.cuda_setup_log = []
|
||||
|
||||
from .cuda_setup.main import evaluate_cuda_setup
|
||||
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
|
||||
|
@ -89,7 +94,8 @@ class CUDASetup:
|
|||
else:
|
||||
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
|
||||
self.lib = ct.cdll.LoadLibrary(binary_path)
|
||||
except:
|
||||
except Exception as ex:
|
||||
self.add_log_entry(str(ex))
|
||||
self.print_log_stack()
|
||||
|
||||
def add_log_entry(self, msg, is_warning=False):
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import ctypes as ct
|
||||
import itertools
|
||||
import operator
|
||||
import random
|
||||
from functools import reduce # Required in Python 3
|
||||
|
@ -130,13 +131,59 @@ class Cusparse_Context:
|
|||
return cls._instance
|
||||
|
||||
|
||||
def create_linear_map(signed=True):
|
||||
if signed:
|
||||
return torch.linspace(-1.0, 1.0, 256)
|
||||
return torch.linspace(0.0, 1.0, 256)
|
||||
def create_linear_map(signed=True, total_bits=8):
|
||||
sign = (-1.0 if signed else 0.0)
|
||||
|
||||
values = torch.linspace(sign, 1.0, 2**total_bits)
|
||||
gap = 256 - values.numel()
|
||||
if gap == 0:
|
||||
return values
|
||||
else:
|
||||
l = values.numel()//2
|
||||
#return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
|
||||
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
|
||||
|
||||
|
||||
def create_dynamic_map(signed=True, n=7):
|
||||
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
|
||||
e = exponent_bits
|
||||
p = precision_bits
|
||||
has_sign = 1 if signed else 0
|
||||
assert e+p == total_bits-has_sign
|
||||
# the exponent is biased to 2^(e-1) -1 == 0
|
||||
evalues = []
|
||||
pvalues = []
|
||||
for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
|
||||
evalues.append(2**val)
|
||||
|
||||
|
||||
lst = list(itertools.product([0, 1], repeat=precision_bits))
|
||||
for bit_pattern in lst:
|
||||
value = 1
|
||||
for i, pval in enumerate(list(bit_pattern)):
|
||||
value += pval*(2**-(i+1))
|
||||
pvalues.append(value)
|
||||
|
||||
assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
|
||||
values = []
|
||||
for ev in evalues:
|
||||
for pv in pvalues:
|
||||
if signed:
|
||||
values.append(-ev*pv)
|
||||
values.append(ev*pv)
|
||||
if total_bits < 8:
|
||||
gap = 256 - len(values)
|
||||
for i in range(gap):
|
||||
values.append(0)
|
||||
values.sort()
|
||||
code = torch.Tensor(values)
|
||||
code /= code.max()
|
||||
code[127] = 0
|
||||
|
||||
return code
|
||||
|
||||
|
||||
|
||||
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
|
||||
"""
|
||||
Creates the dynamic quantiztion map.
|
||||
|
||||
|
@ -157,28 +204,32 @@ def create_dynamic_map(signed=True, n=7):
|
|||
# these are additional items that come from the case
|
||||
# where all the exponent bits are zero and no
|
||||
# indicator bit is present
|
||||
additional_items = 2 ** (7 - n) - 1
|
||||
non_sign_bits = total_bits - (1 if signed else 0)
|
||||
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
|
||||
if not signed:
|
||||
additional_items = 2 * additional_items
|
||||
for i in range(n):
|
||||
fraction_items = (
|
||||
2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
|
||||
)
|
||||
for i in range(max_exponent_bits):
|
||||
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
|
||||
boundaries = torch.linspace(0.1, 1, fraction_items)
|
||||
means = (boundaries[:-1] + boundaries[1:]) / 2.0
|
||||
data += ((10 ** (-(n - 1) + i)) * means).tolist()
|
||||
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
||||
if signed:
|
||||
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
|
||||
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
||||
|
||||
if additional_items > 0:
|
||||
boundaries = torch.linspace(0.1, 1, additional_items + 1)
|
||||
means = (boundaries[:-1] + boundaries[1:]) / 2.0
|
||||
data += ((10 ** (-(n - 1) + i)) * means).tolist()
|
||||
if signed:
|
||||
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
|
||||
if additional_items > 0:
|
||||
boundaries = torch.linspace(0.1, 1, additional_items + 1)
|
||||
means = (boundaries[:-1] + boundaries[1:]) / 2.0
|
||||
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
||||
if signed:
|
||||
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
|
||||
|
||||
data.append(0)
|
||||
data.append(1.0)
|
||||
|
||||
gap = 256 - len(data)
|
||||
for i in range(gap):
|
||||
data.append(0)
|
||||
|
||||
data.sort()
|
||||
return Tensor(data)
|
||||
|
||||
|
@ -322,9 +373,7 @@ def nvidia_transform(
|
|||
return out, new_state
|
||||
|
||||
|
||||
def estimate_quantiles(
|
||||
A: Tensor, out: Tensor = None, offset: float = 1 / 512
|
||||
) -> Tensor:
|
||||
def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
|
||||
'''
|
||||
Estimates 256 equidistant quantiles on the input tensor eCDF.
|
||||
|
||||
|
@ -344,25 +393,36 @@ def estimate_quantiles(
|
|||
out : torch.Tensor
|
||||
Tensor with the 256 estimated quantiles.
|
||||
offset : float
|
||||
The offset for the first and last quantile from 0 and 1. Default: 1/512
|
||||
The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
|
||||
num_quantiles : int
|
||||
The number of equally spaced quantiles.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor:
|
||||
The 256 quantiles in float32 datatype.
|
||||
'''
|
||||
if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
|
||||
if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
|
||||
if num_quantiles < 256 and offset == 1/(512):
|
||||
# override default arguments
|
||||
offset = 1/(2*num_quantiles)
|
||||
|
||||
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
|
||||
is_on_gpu([A, out])
|
||||
device = pre_call(A.device)
|
||||
if A.dtype == torch.float32:
|
||||
lib.cestimate_quantiles_fp32(
|
||||
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
|
||||
)
|
||||
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cestimate_quantiles_fp16(
|
||||
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
|
||||
)
|
||||
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported data type {A.dtype}")
|
||||
post_call(device)
|
||||
|
||||
if num_quantiles < 256:
|
||||
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
|
||||
out = out[idx]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
@ -395,15 +455,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
The quantization state to undo the quantization.
|
||||
"""
|
||||
|
||||
|
||||
if code is None:
|
||||
if "dynamic" not in name2qmap:
|
||||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
code = name2qmap["dynamic"]
|
||||
code = code.to(A.device)
|
||||
|
||||
if absmax is None:
|
||||
n = A.numel()
|
||||
blocksize = (blocksize if A.device.type == 'cpu' else 4096)
|
||||
blocks = n // blocksize
|
||||
blocks += 1 if n % blocksize > 0 else 0
|
||||
absmax = torch.zeros((blocks,), device=A.device)
|
||||
|
@ -412,8 +471,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
is_on_gpu([code, A, absmax, out, rand])
|
||||
assert blocksize in [4096, 2048, 1024, 512]
|
||||
cblocksize = ct.c_int32(blocksize)
|
||||
prev_device = pre_call(A.device)
|
||||
code = code.to(A.device)
|
||||
if rand is not None:
|
||||
is_on_gpu([code, A, out, absmax, rand])
|
||||
assert blocksize==4096
|
||||
assert rand.numel() >= 1024
|
||||
rand_offset = random.randint(0, 1023)
|
||||
if A.dtype == torch.float32:
|
||||
|
@ -421,20 +485,19 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
|
||||
)
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
else:
|
||||
is_on_gpu([code, A, out, absmax])
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
|
||||
)
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
else:
|
||||
# cpu
|
||||
code = code.cpu()
|
||||
assert rand is None
|
||||
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
|
||||
|
||||
|
@ -479,27 +542,30 @@ def dequantize_blockwise(
|
|||
if "dynamic" not in name2qmap:
|
||||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
code = name2qmap["dynamic"]
|
||||
code = code.to(A.device)
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.float32)
|
||||
if quant_state is None:
|
||||
quant_state = (absmax, code)
|
||||
else:
|
||||
absmax, code = quant_state
|
||||
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
if blocksize not in [2048, 4096]:
|
||||
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
|
||||
device = pre_call(A.device)
|
||||
code = code.to(A.device)
|
||||
if blocksize not in [2048, 4096, 1024, 512]:
|
||||
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
|
||||
is_on_gpu([A, out])
|
||||
if out.dtype == torch.float32:
|
||||
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
elif out.dtype == torch.float16:
|
||||
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
|
||||
)
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
else:
|
||||
code = code.cpu()
|
||||
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
|
||||
|
||||
return out
|
||||
|
|
|
@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
|
|||
}
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
|
||||
__launch_bounds__(TH, 4)
|
||||
//__launch_bounds__(TH, 4)
|
||||
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
|
||||
{
|
||||
const int n_full = gridDim.x * BLOCK_SIZE;
|
||||
int valid_items = 0;
|
||||
const int base_idx = (blockIdx.x * BLOCK_SIZE);
|
||||
|
||||
T vals[NUM];
|
||||
float rand_vals[NUM];
|
||||
unsigned char qvals[NUM];
|
||||
T vals[NUM_PER_TH];
|
||||
float rand_vals[NUM_PER_TH];
|
||||
unsigned char qvals[NUM_PER_TH];
|
||||
//float local_abs_max = -FLT_MAX;
|
||||
float local_abs_max = 0.0f;
|
||||
int local_rand_idx = 0;
|
||||
|
@ -510,15 +510,15 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
}
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
|
||||
__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n)
|
||||
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
|
||||
{
|
||||
|
||||
const int n_full = gridDim.x * BLOCK_SIZE;
|
||||
int valid_items = 0;
|
||||
const int base_idx = (blockIdx.x * BLOCK_SIZE);
|
||||
|
||||
T vals[NUM];
|
||||
unsigned char qvals[NUM];
|
||||
T vals[NUM_PER_TH];
|
||||
unsigned char qvals[NUM_PER_TH];
|
||||
float local_abs_max = -FLT_MAX;
|
||||
|
||||
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
|
@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
|
|||
|
||||
__shared__ typename LoadChar::TempStorage loadchar;
|
||||
__shared__ typename StoreT::TempStorage storet;
|
||||
__shared__ float smem_code[256];
|
||||
//__shared__ float smem_code[256];
|
||||
//float local_code[16];
|
||||
|
||||
if(threadIdx.x < 256)
|
||||
smem_code[threadIdx.x] = code[threadIdx.x];
|
||||
//if(threadIdx.x < 256)
|
||||
//smem_code[threadIdx.x] = code[threadIdx.x];
|
||||
|
||||
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
|
||||
{
|
||||
|
@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
|
|||
__syncthreads();
|
||||
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
|
||||
|
||||
// load code through read-only cache via __ldg
|
||||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH; j++)
|
||||
vals[j] = smem_code[qvals[j]]*local_abs_max;
|
||||
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
|
||||
|
||||
__syncthreads();
|
||||
StoreT(storet).Store(&(out[i]), vals, valid_items);
|
||||
|
@ -2791,11 +2793,21 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half
|
|||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
|
||||
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
|
|||
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||
|
|
33
csrc/ops.cu
33
csrc/ops.cu
|
@ -50,11 +50,23 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
|
|||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
|
||||
{
|
||||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
if(STOCHASTIC == 1)
|
||||
assert(blocksize == 4096);
|
||||
|
||||
if(blocksize == 4096)
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 2048)
|
||||
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 1024)
|
||||
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 512)
|
||||
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -66,6 +78,11 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
|
|||
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 2048)
|
||||
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 1024)
|
||||
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 512)
|
||||
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -659,10 +676,10 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
|
|||
template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
|
||||
|
|
|
@ -128,7 +128,7 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
|
|||
|
||||
void quantize(float *code, float *A, unsigned char *out, int n);
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n);
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
|
|
|
@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
|||
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
|
||||
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
|
||||
|
@ -140,8 +140,8 @@ extern "C"
|
|||
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
||||
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
||||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
|
||||
|
||||
|
|
|
@ -6,12 +6,14 @@ from itertools import product
|
|||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes import functional as F
|
||||
from scipy.stats import norm
|
||||
|
||||
torch.set_printoptions(
|
||||
precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
|
||||
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
|
||||
)
|
||||
k = 20
|
||||
|
||||
|
@ -149,30 +151,41 @@ def test_dynamic_quantization():
|
|||
|
||||
|
||||
def test_dynamic_blockwise_quantization():
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
assert diffs[-1] < 0.011
|
||||
# print(sum(diffs)/len(diffs))
|
||||
# print(sum(reldiffs)/len(reldiffs))
|
||||
#print('')
|
||||
for blocksize in [4096, 2048, 1024, 512]:
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
|
||||
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.011
|
||||
assert relerr < 0.018
|
||||
#print('randn', blocksize, sum(diffs)/len(diffs))
|
||||
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2).mean().item()
|
||||
assert diff < 0.0033
|
||||
diffs.append(diff)
|
||||
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
# print(sum(diffs)/len(diffs))
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
|
||||
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.0035
|
||||
assert relerr < 0.015
|
||||
#print('rand', blocksize, sum(diffs)/len(diffs))
|
||||
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
||||
def test_dynamic_blockwise_stochastic_quantization():
|
||||
|
@ -1616,17 +1629,6 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
|
|||
# print(time.time() - t0)
|
||||
|
||||
|
||||
def test_layout():
|
||||
a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
|
||||
a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
|
||||
a2, s2 = F.transform(a1, "col_turing")
|
||||
print(a2.shape)
|
||||
|
||||
print(a1.flatten()[8 * 64 : 8 * 64 + 32])
|
||||
for i in range(4):
|
||||
print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
|
||||
|
||||
|
||||
def test_coo2csr():
|
||||
threshold = 1
|
||||
A = torch.randn(128, 128).half().cuda()
|
||||
|
@ -2040,3 +2042,143 @@ def test_blockwise_cpu_large():
|
|||
assert diffs[-1] < 0.011
|
||||
# print(sum(diffs)/len(diffs))
|
||||
# print(sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
||||
|
||||
def test_fp8_quant():
|
||||
for e_bits in range(1, 7):
|
||||
p_bits = 7-e_bits
|
||||
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
|
||||
|
||||
print(e_bits, p_bits)
|
||||
abserr = []
|
||||
relerr = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, SC = F.quantize_blockwise(A1, code=code)
|
||||
A2 = F.dequantize_blockwise(C, SC)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff/torch.abs(A1+1e-8)
|
||||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
#print(sum(abserr)/len(abserr))
|
||||
#print(sum(relerr)/len(relerr))
|
||||
|
||||
abserr = []
|
||||
relerr = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, SC = F.quantize_blockwise(A1, code=code)
|
||||
A2 = F.dequantize_blockwise(C, SC)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff/torch.abs(A1+1e-8)
|
||||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
#print(sum(abserr)/len(abserr))
|
||||
#print(sum(relerr)/len(relerr))
|
||||
|
||||
abserr = []
|
||||
relerr = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, SC = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, SC)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff/torch.abs(A1+1e-8)
|
||||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
#print(3, sum(abserr)/len(abserr))
|
||||
#print(3, sum(relerr)/len(relerr))
|
||||
|
||||
|
||||
def test_few_bit_quant():
|
||||
|
||||
#print('')
|
||||
for bits in range(2, 9):
|
||||
#print('='*30, bits, '='*30)
|
||||
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
|
||||
abserrs = []
|
||||
relerrs = []
|
||||
code = None
|
||||
if method == 'linear':
|
||||
code = F.create_linear_map(True, total_bits=bits).cuda()
|
||||
elif method == 'fp8':
|
||||
ebits = math.ceil(bits/2)
|
||||
pbits = bits-ebits-1
|
||||
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
|
||||
elif method == 'dynamic':
|
||||
code = F.create_dynamic_map(True, bits-0, bits).cuda()
|
||||
elif method == 'quantile':
|
||||
values = torch.randn(2048, 2048, device='cuda')
|
||||
q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits)
|
||||
gap = 256-q.numel()
|
||||
q = q.tolist()
|
||||
for i in range(gap):
|
||||
q.append(0)
|
||||
q = torch.Tensor(q).cuda()
|
||||
|
||||
q /= q.abs().max()
|
||||
code, idx = torch.sort(q)
|
||||
#print(method, (code==0).sum())
|
||||
assert code.numel() == 256
|
||||
for i in range(10):
|
||||
|
||||
values = torch.randn(1, 32, device='cuda')
|
||||
values /= values.abs().max()
|
||||
#values[values.abs() < 1e-6] += 1e-5
|
||||
|
||||
q1 = []
|
||||
v1 = []
|
||||
for v in values[0]:
|
||||
idx = torch.abs(v-code).argmin()
|
||||
q1.append(idx.item())
|
||||
v1.append(code[idx].item())
|
||||
|
||||
q1 = torch.Tensor(q1).cuda()
|
||||
v1 = torch.Tensor(v1).cuda()
|
||||
|
||||
q2, S2 = F.quantize(values, code=code)
|
||||
v2 = F.dequantize(q2, S2)
|
||||
|
||||
idx = torch.isclose(q1.int(), q2.int())
|
||||
err2 = torch.abs(v2-values)
|
||||
abserrs.append(err2.mean().item())
|
||||
relerrs.append((err2/(1e-10+values).abs()).mean().item())
|
||||
if idx.sum():
|
||||
# some weird cases
|
||||
err1 = torch.abs(v1-values).mean()
|
||||
assert err2.mean() <= err1
|
||||
|
||||
else:
|
||||
torch.testing.assert_allclose(q1, q2)
|
||||
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
|
||||
|
||||
|
||||
def test_kbit_quantile_estimation():
|
||||
for i in range(100):
|
||||
data = torch.randn(1024, 1024, device='cuda')
|
||||
for bits in range(2, 9):
|
||||
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
|
||||
val1 = torch.Tensor(norm.ppf(p)).cuda()
|
||||
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
|
||||
err = torch.abs(val1-val2).mean()
|
||||
assert err < 0.035
|
||||
|
||||
|
||||
def test_bench_dequantization():
|
||||
a = torch.rand(1024, 1024, device='cuda').half()
|
||||
qa, SA = F.quantize_blockwise(a)
|
||||
|
||||
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
|
||||
#print(max_theoretical_mu)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
F.dequantize_blockwise(qa, SA, blocksize=2048)
|
||||
torch.cuda.synchronize()
|
||||
#print((time.time()-t0)/1e6)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user