Merge branch 'main' into cleanup

This commit is contained in:
Tim Dettmers 2023-01-02 11:19:16 +01:00 committed by GitHub
commit c91f592ad7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 34 deletions

View File

@ -6,10 +6,12 @@ import ctypes as ct
import itertools
import operator
import random
import torch
import itertools
import math
from functools import reduce # Required in Python 3
from typing import Tuple
import torch
from torch import Tensor
from .cextension import COMPILED_WITH_CUDA, lib
@ -131,10 +133,17 @@ class Cusparse_Context:
return cls._instance
def create_linear_map(signed=True, total_bits=8):
def create_linear_map(signed=True, total_bits=8, add_zero=True):
sign = (-1.0 if signed else 0.0)
total_values = 2**total_bits
if add_zero or total_bits < 8:
# add a zero
# since we simulate less bits by having zeros in the data type, we
# we need to center the quantization around zero and as such lose
# a single value
total_values = (2**total_bits if not signed else 2**total_bits-1)
values = torch.linspace(sign, 1.0, 2**total_bits)
values = torch.linspace(sign, 1.0, total_values)
gap = 256 - values.numel()
if gap == 0:
return values
@ -156,20 +165,28 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
evalues.append(2**val)
lst = list(itertools.product([0, 1], repeat=precision_bits))
for bit_pattern in lst:
value = 1
for i, pval in enumerate(list(bit_pattern)):
value += pval*(2**-(i+1))
pvalues.append(value)
assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
values = []
for ev in evalues:
for pv in pvalues:
lst = list(itertools.product([0, 1], repeat=precision_bits))
#for ev in evalues:
bias = 2**(exponent_bits-1)-1
for evalue in range(2**(exponent_bits)):
for bit_pattern in lst:
value = (1 if evalue != 0 else 0)
for i, pval in enumerate(list(bit_pattern)):
value += pval*(2**-(i+1))
if evalue == 0:
# subnormals
value = value*2**-(bias-1)
else:
# normals
value = value*2**-(evalue-bias-2)
values.append(value)
if signed:
values.append(-ev*pv)
values.append(ev*pv)
values.append(-value)
assert len(values) == 2**total_bits
values.sort()
if total_bits < 8:
gap = 256 - len(values)
for i in range(gap):
@ -177,7 +194,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
values.sort()
code = torch.Tensor(values)
code /= code.max()
code[127] = 0
return code
@ -233,6 +249,20 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.sort()
return Tensor(data)
def create_quantile_map(A, total_bits=8):
q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
q = q.tolist()
q.append(0)
gap = 256 - len(q)
for i in range(gap):
q.append(0)
q.sort()
q = Tensor(q)
q = q/q.abs().max()
return q
def get_special_format_str():
if not torch.cuda.is_available(): return 'col_turing'
@ -420,6 +450,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
post_call(device)
if num_quantiles < 256:
step = round(256/num_quantiles)
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
out = out[idx]
@ -471,7 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu':
assert blocksize in [4096, 2048, 1024, 512]
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
@ -554,8 +585,8 @@ def dequantize_blockwise(
if A.device.type != 'cpu':
device = pre_call(A.device)
code = code.to(A.device)
if blocksize not in [2048, 4096, 1024, 512]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))

View File

@ -454,8 +454,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
__shared__ float smem_code[256];
__shared__ float smem_absmax_value[1];
if(threadIdx.x < 256)
smem_code[threadIdx.x] = code[threadIdx.x];
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
smem_code[i] = code[i];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
@ -2799,6 +2799,12 @@ template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 64, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 64, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
@ -2808,6 +2814,12 @@ template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, u
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 256, 128, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 256, 128, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 128, 64, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 128, 64, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 64, 64, 1>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 64, 64, 1>(float *code, unsigned char * A, float * absmax, float *out, const int n);

View File

@ -65,6 +65,12 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 256)
kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 64)
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
@ -82,6 +88,12 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
else if(blocksize == 512)
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
else if(blocksize == 256)
kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
else if(blocksize == 128)
kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
else if(blocksize == 64)
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

View File

@ -2113,15 +2113,11 @@ def test_few_bit_quant():
code = F.create_dynamic_map(True, bits-0, bits).cuda()
elif method == 'quantile':
values = torch.randn(2048, 2048, device='cuda')
q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits)
gap = 256-q.numel()
q = q.tolist()
for i in range(gap):
q.append(0)
q = torch.Tensor(q).cuda()
q /= q.abs().max()
code, idx = torch.sort(q)
code = F.create_quantile_map(values, bits).cuda()
# for some data types we have no zero
# for some data types we have one zero
# for some data types we have two zeros
assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
#print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
@ -2140,8 +2136,8 @@ def test_few_bit_quant():
q1 = torch.Tensor(q1).cuda()
v1 = torch.Tensor(v1).cuda()
q2, S2 = F.quantize(values, code=code)
v2 = F.dequantize(q2, S2)
q2, S2 = F.quantize_blockwise(values, code=code)
v2 = F.dequantize_blockwise(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
err2 = torch.abs(v2-values)
@ -2150,11 +2146,12 @@ def test_few_bit_quant():
if idx.sum():
# some weird cases
err1 = torch.abs(v1-values).mean()
assert err2.mean() <= err1
#assert err2.mean() <= err1
else:
torch.testing.assert_allclose(q1, q2)
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
#assert False
def test_kbit_quantile_estimation():
@ -2165,6 +2162,20 @@ def test_kbit_quantile_estimation():
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1-val2).mean()
assert err < 0.038
for i in range(100):
data = torch.randn(1024, 1024, device='cuda')
for bits in range(2, 4):
total_values = 2**bits-1
p = np.linspace(0, 1, 2*total_values+1)
idx = np.arange(1, 2*total_values+1, 2)
p = p[idx]
offset = 1/(2*total_values)
p = np.linspace(offset, 1-offset, total_values)
val1 = torch.Tensor(norm.ppf(p)).cuda()
val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
err = torch.abs(val1-val2).mean()
assert err < 0.035