Merge branch 'main' into cleanup

This commit is contained in:
Tom Aarsen 2022-11-17 15:22:29 +01:00 committed by GitHub
commit b104ce3b62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 358 additions and 115 deletions

View File

@ -52,8 +52,13 @@ class CUDASetup:
self.add_log_entry('python setup.py install') self.add_log_entry('python setup.py install')
def initialize(self): def initialize(self):
self.cuda_setup_log = [] self.has_printed = False
self.lib = None self.lib = None
self.run_cuda_setup()
def run_cuda_setup(self):
self.initialized = True
self.cuda_setup_log = []
from .cuda_setup.main import evaluate_cuda_setup from .cuda_setup.main import evaluate_cuda_setup
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup() binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
@ -89,7 +94,8 @@ class CUDASetup:
else: else:
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...") self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
self.lib = ct.cdll.LoadLibrary(binary_path) self.lib = ct.cdll.LoadLibrary(binary_path)
except: except Exception as ex:
self.add_log_entry(str(ex))
self.print_log_stack() self.print_log_stack()
def add_log_entry(self, msg, is_warning=False): def add_log_entry(self, msg, is_warning=False):

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import ctypes as ct import ctypes as ct
import itertools
import operator import operator
import random import random
from functools import reduce # Required in Python 3 from functools import reduce # Required in Python 3
@ -130,13 +131,59 @@ class Cusparse_Context:
return cls._instance return cls._instance
def create_linear_map(signed=True): def create_linear_map(signed=True, total_bits=8):
if signed: sign = (-1.0 if signed else 0.0)
return torch.linspace(-1.0, 1.0, 256)
return torch.linspace(0.0, 1.0, 256) values = torch.linspace(sign, 1.0, 2**total_bits)
gap = 256 - values.numel()
if gap == 0:
return values
else:
l = values.numel()//2
#return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
def create_dynamic_map(signed=True, n=7): def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
e = exponent_bits
p = precision_bits
has_sign = 1 if signed else 0
assert e+p == total_bits-has_sign
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
pvalues = []
for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
evalues.append(2**val)
lst = list(itertools.product([0, 1], repeat=precision_bits))
for bit_pattern in lst:
value = 1
for i, pval in enumerate(list(bit_pattern)):
value += pval*(2**-(i+1))
pvalues.append(value)
assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
values = []
for ev in evalues:
for pv in pvalues:
if signed:
values.append(-ev*pv)
values.append(ev*pv)
if total_bits < 8:
gap = 256 - len(values)
for i in range(gap):
values.append(0)
values.sort()
code = torch.Tensor(values)
code /= code.max()
code[127] = 0
return code
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
""" """
Creates the dynamic quantiztion map. Creates the dynamic quantiztion map.
@ -157,28 +204,32 @@ def create_dynamic_map(signed=True, n=7):
# these are additional items that come from the case # these are additional items that come from the case
# where all the exponent bits are zero and no # where all the exponent bits are zero and no
# indicator bit is present # indicator bit is present
additional_items = 2 ** (7 - n) - 1 non_sign_bits = total_bits - (1 if signed else 0)
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
if not signed: if not signed:
additional_items = 2 * additional_items additional_items = 2 * additional_items
for i in range(n): for i in range(max_exponent_bits):
fraction_items = ( fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
)
boundaries = torch.linspace(0.1, 1, fraction_items) boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1] + boundaries[1:]) / 2.0 means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist() data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed: if signed:
data += (-(10 ** (-(n - 1) + i)) * means).tolist() data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if additional_items > 0: if additional_items > 0:
boundaries = torch.linspace(0.1, 1, additional_items + 1) boundaries = torch.linspace(0.1, 1, additional_items + 1)
means = (boundaries[:-1] + boundaries[1:]) / 2.0 means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist() data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed: if signed:
data += (-(10 ** (-(n - 1) + i)) * means).tolist() data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
data.append(0) data.append(0)
data.append(1.0) data.append(1.0)
gap = 256 - len(data)
for i in range(gap):
data.append(0)
data.sort() data.sort()
return Tensor(data) return Tensor(data)
@ -322,9 +373,7 @@ def nvidia_transform(
return out, new_state return out, new_state
def estimate_quantiles( def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
''' '''
Estimates 256 equidistant quantiles on the input tensor eCDF. Estimates 256 equidistant quantiles on the input tensor eCDF.
@ -344,25 +393,36 @@ def estimate_quantiles(
out : torch.Tensor out : torch.Tensor
Tensor with the 256 estimated quantiles. Tensor with the 256 estimated quantiles.
offset : float offset : float
The offset for the first and last quantile from 0 and 1. Default: 1/512 The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
num_quantiles : int
The number of equally spaced quantiles.
Returns Returns
------- -------
torch.Tensor: torch.Tensor:
The 256 quantiles in float32 datatype. The 256 quantiles in float32 datatype.
''' '''
if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
if num_quantiles < 256 and offset == 1/(512):
# override default arguments
offset = 1/(2*num_quantiles)
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
is_on_gpu([A, out]) is_on_gpu([A, out])
device = pre_call(A.device)
if A.dtype == torch.float32: if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32( lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
)
elif A.dtype == torch.float16: elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16( lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
)
else: else:
raise NotImplementedError(f"Not supported data type {A.dtype}") raise NotImplementedError(f"Not supported data type {A.dtype}")
post_call(device)
if num_quantiles < 256:
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
out = out[idx]
return out return out
@ -395,15 +455,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
The quantization state to undo the quantization. The quantization state to undo the quantization.
""" """
if code is None: if code is None:
if "dynamic" not in name2qmap: if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device) name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"] code = name2qmap["dynamic"]
code = code.to(A.device)
if absmax is None: if absmax is None:
n = A.numel() n = A.numel()
blocksize = (blocksize if A.device.type == 'cpu' else 4096)
blocks = n // blocksize blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0 blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device) absmax = torch.zeros((blocks,), device=A.device)
@ -412,8 +471,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
out = torch.zeros_like(A, dtype=torch.uint8) out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu': if A.device.type != 'cpu':
is_on_gpu([code, A, absmax, out, rand]) assert blocksize in [4096, 2048, 1024, 512]
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
if rand is not None: if rand is not None:
is_on_gpu([code, A, out, absmax, rand])
assert blocksize==4096
assert rand.numel() >= 1024 assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023) rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32: if A.dtype == torch.float32:
@ -421,20 +485,19 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
elif A.dtype == torch.float16: elif A.dtype == torch.float16:
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
else: else:
raise ValueError( raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
else: else:
is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32: if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
elif A.dtype == torch.float16: elif A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
else: else:
raise ValueError( raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" post_call(A.device)
)
else: else:
# cpu # cpu
code = code.cpu()
assert rand is None assert rand is None
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
@ -479,27 +542,30 @@ def dequantize_blockwise(
if "dynamic" not in name2qmap: if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device) name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"] code = name2qmap["dynamic"]
code = code.to(A.device)
if out is None: if out is None:
out = torch.zeros_like(A, dtype=torch.float32) out = torch.zeros_like(A, dtype=torch.float32)
if quant_state is None: if quant_state is None:
quant_state = (absmax, code) quant_state = (absmax, code)
else:
absmax, code = quant_state
if A.device.type != 'cpu': if A.device.type != 'cpu':
if blocksize not in [2048, 4096]: device = pre_call(A.device)
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]") code = code.to(A.device)
if blocksize not in [2048, 4096, 1024, 512]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
is_on_gpu([A, out]) is_on_gpu([A, out])
if out.dtype == torch.float32: if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16: elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
else: else:
raise ValueError( raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" post_call(A.device)
)
else: else:
code = code.cpu()
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
return out return out

View File

@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
} }
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
__launch_bounds__(TH, 4) //__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{ {
const int n_full = gridDim.x * BLOCK_SIZE; const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0; int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE); const int base_idx = (blockIdx.x * BLOCK_SIZE);
T vals[NUM]; T vals[NUM_PER_TH];
float rand_vals[NUM]; float rand_vals[NUM_PER_TH];
unsigned char qvals[NUM]; unsigned char qvals[NUM_PER_TH];
//float local_abs_max = -FLT_MAX; //float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f; float local_abs_max = 0.0f;
int local_rand_idx = 0; int local_rand_idx = 0;
@ -510,15 +510,15 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
} }
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n) __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
{ {
const int n_full = gridDim.x * BLOCK_SIZE; const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0; int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE); const int base_idx = (blockIdx.x * BLOCK_SIZE);
T vals[NUM]; T vals[NUM_PER_TH];
unsigned char qvals[NUM]; unsigned char qvals[NUM_PER_TH];
float local_abs_max = -FLT_MAX; float local_abs_max = -FLT_MAX;
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar; typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
__shared__ typename LoadChar::TempStorage loadchar; __shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet; __shared__ typename StoreT::TempStorage storet;
__shared__ float smem_code[256]; //__shared__ float smem_code[256];
//float local_code[16];
if(threadIdx.x < 256) //if(threadIdx.x < 256)
smem_code[threadIdx.x] = code[threadIdx.x]; //smem_code[threadIdx.x] = code[threadIdx.x];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{ {
@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
__syncthreads(); __syncthreads();
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH #pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++) for(int j = 0; j < NUM_PER_TH; j++)
vals[j] = smem_code[qvals[j]]*local_abs_max; vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
__syncthreads(); __syncthreads();
StoreT(storet).Store(&(out[i]), vals, valid_items); StoreT(storet).Store(&(out[i]), vals, valid_items);
@ -2791,11 +2793,21 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);

View File

@ -15,7 +15,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n); template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS> template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,

View File

@ -50,11 +50,23 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{ {
int num_blocks = n/4096; int num_blocks = n/blocksize;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n); if(STOCHASTIC == 1)
assert(blocksize == 4096);
if(blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 2048)
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024)
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
@ -66,6 +78,11 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n); kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
else if(blocksize == 2048) else if(blocksize == 2048)
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n); kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
else if(blocksize == 1024)
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
else if(blocksize == 512)
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
@ -659,10 +676,10 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);

View File

@ -128,7 +128,7 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
void quantize(float *code, float *A, unsigned char *out, int n); void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n);
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,

View File

@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); } void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, n); } void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, n); } void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, n); } void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, n); } void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); } void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
@ -140,8 +140,8 @@ extern "C"
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); } void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); } void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); } void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }

View File

@ -6,12 +6,14 @@ from itertools import product
import einops import einops
import pytest import pytest
import torch import torch
import numpy as np
import bitsandbytes as bnb import bitsandbytes as bnb
from bitsandbytes import functional as F from bitsandbytes import functional as F
from scipy.stats import norm
torch.set_printoptions( torch.set_printoptions(
precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
) )
k = 20 k = 20
@ -149,30 +151,41 @@ def test_dynamic_quantization():
def test_dynamic_blockwise_quantization(): def test_dynamic_blockwise_quantization():
diffs = [] #print('')
reldiffs = [] for blocksize in [4096, 2048, 1024, 512]:
for i in range(100): diffs = []
A1 = torch.randn(1024, 1024, device="cuda") reldiffs = []
C, S = F.quantize_blockwise(A1) for i in range(100):
A2 = F.dequantize_blockwise(C, S) A1 = torch.randn(1024, 1024, device="cuda")
diff = torch.abs(A1 - A2) C, S = F.quantize_blockwise(A1, blocksize=blocksize)
reldiff = diff / torch.abs(A1 + 1e-8) A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
diffs.append(diff.mean().item()) diff = torch.abs(A1 - A2)
reldiffs.append(reldiff.mean().item()) reldiff = diff / torch.abs(A1 + 1e-8)
assert diffs[-1] < 0.011 diffs.append(diff.mean().item())
# print(sum(diffs)/len(diffs)) reldiffs.append(reldiff.mean().item())
# print(sum(reldiffs)/len(reldiffs)) abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.011
assert relerr < 0.018
#print('randn', blocksize, sum(diffs)/len(diffs))
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
diffs = [] diffs = []
for i in range(100): for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda") A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1) C, S = F.quantize_blockwise(A1, blocksize=blocksize)
A2 = F.dequantize_blockwise(C, S) A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
diff = torch.abs(A1 - A2).mean().item() diff = torch.abs(A1 - A2)
assert diff < 0.0033 reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff) diffs.append(diff.mean().item())
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) reldiffs.append(reldiff.mean().item())
# print(sum(diffs)/len(diffs)) #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
abserr = sum(diffs)/len(diffs)
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.0035
assert relerr < 0.015
#print('rand', blocksize, sum(diffs)/len(diffs))
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
def test_dynamic_blockwise_stochastic_quantization(): def test_dynamic_blockwise_stochastic_quantization():
@ -1616,17 +1629,6 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
# print(time.time() - t0) # print(time.time() - t0)
def test_layout():
a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
a2, s2 = F.transform(a1, "col_turing")
print(a2.shape)
print(a1.flatten()[8 * 64 : 8 * 64 + 32])
for i in range(4):
print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
def test_coo2csr(): def test_coo2csr():
threshold = 1 threshold = 1
A = torch.randn(128, 128).half().cuda() A = torch.randn(128, 128).half().cuda()
@ -2040,3 +2042,143 @@ def test_blockwise_cpu_large():
assert diffs[-1] < 0.011 assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs)) # print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs)) # print(sum(reldiffs)/len(reldiffs))
def test_fp8_quant():
for e_bits in range(1, 7):
p_bits = 7-e_bits
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
print(e_bits, p_bits)
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.rand(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(sum(abserr)/len(abserr))
#print(sum(relerr)/len(relerr))
abserr = []
relerr = []
for i in range(100):
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
diff = torch.abs(A1 - A2)
reldiff = diff/torch.abs(A1+1e-8)
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
#print(3, sum(abserr)/len(abserr))
#print(3, sum(relerr)/len(relerr))
def test_few_bit_quant():
#print('')
for bits in range(2, 9):
#print('='*30, bits, '='*30)
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
abserrs = []
relerrs = []
code = None
if method == 'linear':
code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == 'fp8':
ebits = math.ceil(bits/2)
pbits = bits-ebits-1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
elif method == 'dynamic':
code = F.create_dynamic_map(True, bits-0, bits).cuda()
elif method == 'quantile':
values = torch.randn(2048, 2048, device='cuda')
q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits)
gap = 256-q.numel()
q = q.tolist()
for i in range(gap):
q.append(0)
q = torch.Tensor(q).cuda()
q /= q.abs().max()
code, idx = torch.sort(q)
#print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
values = torch.randn(1, 32, device='cuda')
values /= values.abs().max()
#values[values.abs() < 1e-6] += 1e-5
q1 = []
v1 = []
for v in values[0]:
idx = torch.abs(v-code).argmin()
q1.append(idx.item())
v1.append(code[idx].item())
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q2, S2 = F.quantize(values, code=code)
v2 = F.dequantize(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2-values)
abserrs.append(err2.mean().item())
relerrs.append((err2/(1e-10+values).abs()).mean().item())
if idx.sum():
# some weird cases
err1 = torch.abs(v1-values).mean()
assert err2.mean() <= err1
else:
torch.testing.assert_allclose(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
def test_kbit_quantile_estimation():
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
for bits in range(2, 9):
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1-val2).mean()
assert err < 0.035
def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half()
qa, SA = F.quantize_blockwise(a)
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
#print(max_theoretical_mu)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
F.dequantize_blockwise(qa, SA, blocksize=2048)
torch.cuda.synchronize()
#print((time.time()-t0)/1e6)