Added blocksizes 2048, 1024, and 512 to blockwise quant.
This commit is contained in:
parent
2f2063bac2
commit
6bc2b992be
|
@ -52,8 +52,13 @@ class CUDASetup(object):
|
|||
self.add_log_entry('python setup.py install')
|
||||
|
||||
def initialize(self):
|
||||
self.cuda_setup_log = []
|
||||
self.has_printed = False
|
||||
self.lib = None
|
||||
self.run_cuda_setup()
|
||||
|
||||
def run_cuda_setup(self):
|
||||
self.initialized = True
|
||||
self.cuda_setup_log = []
|
||||
|
||||
from .cuda_setup.main import evaluate_cuda_setup
|
||||
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
|
||||
|
@ -89,7 +94,9 @@ class CUDASetup(object):
|
|||
else:
|
||||
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
|
||||
self.lib = ct.cdll.LoadLibrary(binary_path)
|
||||
except:
|
||||
print(self.lib)
|
||||
except Exception as ex:
|
||||
self.add_log_entry(str(ex))
|
||||
self.print_log_stack()
|
||||
|
||||
def add_log_entry(self, msg, is_warning=False):
|
||||
|
|
|
@ -130,10 +130,10 @@ class Cusparse_Context(object):
|
|||
return cls._instance
|
||||
|
||||
|
||||
def create_linear_map(signed=True, bits=8):
|
||||
def create_linear_map(signed=True, total_bits=8):
|
||||
sign = (-1.0 if signed else 0.0)
|
||||
|
||||
values = torch.linspace(sign, 1.0, 2**bits)
|
||||
values = torch.linspace(sign, 1.0, 2**total_bits)
|
||||
gap = 256 - values.numel()
|
||||
if gap == 0:
|
||||
return values
|
||||
|
@ -457,6 +457,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
The quantization state to undo the quantization.
|
||||
"""
|
||||
|
||||
|
||||
if code is None:
|
||||
if "dynamic" not in name2qmap:
|
||||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
|
@ -474,8 +475,11 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
assert blocksize in [4096, 2048, 1024, 512]
|
||||
is_on_gpu([code, A, absmax, out, rand])
|
||||
cblocksize = ct.c_int32(blocksize)
|
||||
if rand is not None:
|
||||
assert blocksize==4096
|
||||
assert rand.numel() >= 1024
|
||||
rand_offset = random.randint(0, 1023)
|
||||
if A.dtype == torch.float32:
|
||||
|
@ -483,18 +487,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
|
||||
)
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
else:
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
|
||||
)
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
else:
|
||||
# cpu
|
||||
assert rand is None
|
||||
|
|
|
@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
|
|||
}
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
|
||||
__launch_bounds__(TH, 4)
|
||||
//__launch_bounds__(TH, 4)
|
||||
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
|
||||
{
|
||||
const int n_full = gridDim.x * BLOCK_SIZE;
|
||||
int valid_items = 0;
|
||||
const int base_idx = (blockIdx.x * BLOCK_SIZE);
|
||||
|
||||
T vals[NUM];
|
||||
float rand_vals[NUM];
|
||||
unsigned char qvals[NUM];
|
||||
T vals[NUM_PER_TH];
|
||||
float rand_vals[NUM_PER_TH];
|
||||
unsigned char qvals[NUM_PER_TH];
|
||||
//float local_abs_max = -FLT_MAX;
|
||||
float local_abs_max = 0.0f;
|
||||
int local_rand_idx = 0;
|
||||
|
@ -517,8 +517,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
|
|||
int valid_items = 0;
|
||||
const int base_idx = (blockIdx.x * BLOCK_SIZE);
|
||||
|
||||
T vals[NUM];
|
||||
unsigned char qvals[NUM];
|
||||
T vals[NUM_PER_TH];
|
||||
unsigned char qvals[NUM_PER_TH];
|
||||
float local_abs_max = -FLT_MAX;
|
||||
|
||||
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
|
@ -2791,11 +2791,21 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half
|
|||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
|
||||
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
|
||||
|
||||
|
||||
|
|
33
csrc/ops.cu
33
csrc/ops.cu
|
@ -50,11 +50,23 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
|
|||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
|
||||
{
|
||||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
if(STOCHASTIC == 1)
|
||||
assert(blocksize == 4096);
|
||||
|
||||
if(blocksize == 4096)
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 2048)
|
||||
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 1024)
|
||||
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 512)
|
||||
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -66,6 +78,11 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
|
|||
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 2048)
|
||||
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 1024)
|
||||
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 512)
|
||||
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -659,10 +676,10 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
|
|||
template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
|
||||
|
|
|
@ -128,7 +128,7 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
|
|||
|
||||
void quantize(float *code, float *A, unsigned char *out, int n);
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n);
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
|
|
|
@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
|||
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
|
||||
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
|
||||
|
@ -140,8 +140,8 @@ extern "C"
|
|||
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
||||
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
||||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
|
||||
|
||||
|
|
|
@ -151,30 +151,41 @@ def test_dynamic_quantization():
|
|||
|
||||
|
||||
def test_dynamic_blockwise_quantization():
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
assert diffs[-1] < 0.011
|
||||
# print(sum(diffs)/len(diffs))
|
||||
# print(sum(reldiffs)/len(reldiffs))
|
||||
#print('')
|
||||
for blocksize in [4096, 2048, 1024, 512]:
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.011
|
||||
assert relerr < 0.018
|
||||
#print('randn', blocksize, sum(diffs)/len(diffs))
|
||||
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2).mean().item()
|
||||
assert diff < 0.0033
|
||||
diffs.append(diff)
|
||||
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
# print(sum(diffs)/len(diffs))
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.0035
|
||||
assert relerr < 0.015
|
||||
#print('rand', blocksize, sum(diffs)/len(diffs))
|
||||
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
||||
def test_dynamic_blockwise_stochastic_quantization():
|
||||
|
@ -1618,17 +1629,6 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
|
|||
# print(time.time() - t0)
|
||||
|
||||
|
||||
def test_layout():
|
||||
a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
|
||||
a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
|
||||
a2, s2 = F.transform(a1, "col_turing")
|
||||
print(a2.shape)
|
||||
|
||||
print(a1.flatten()[8 * 64 : 8 * 64 + 32])
|
||||
for i in range(4):
|
||||
print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
|
||||
|
||||
|
||||
def test_coo2csr():
|
||||
threshold = 1
|
||||
A = torch.randn(128, 128).half().cuda()
|
||||
|
@ -2062,8 +2062,8 @@ def test_fp8_quant():
|
|||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
print(sum(abserr)/len(abserr))
|
||||
print(sum(relerr)/len(relerr))
|
||||
#print(sum(abserr)/len(abserr))
|
||||
#print(sum(relerr)/len(relerr))
|
||||
|
||||
abserr = []
|
||||
relerr = []
|
||||
|
@ -2076,8 +2076,8 @@ def test_fp8_quant():
|
|||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
print(sum(abserr)/len(abserr))
|
||||
print(sum(relerr)/len(relerr))
|
||||
#print(sum(abserr)/len(abserr))
|
||||
#print(sum(relerr)/len(relerr))
|
||||
|
||||
abserr = []
|
||||
relerr = []
|
||||
|
@ -2090,21 +2090,21 @@ def test_fp8_quant():
|
|||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
print(3, sum(abserr)/len(abserr))
|
||||
print(3, sum(relerr)/len(relerr))
|
||||
#print(3, sum(abserr)/len(abserr))
|
||||
#print(3, sum(relerr)/len(relerr))
|
||||
|
||||
|
||||
def test_few_bit_quant():
|
||||
|
||||
print('')
|
||||
#print('')
|
||||
for bits in range(2, 9):
|
||||
print('='*30, bits, '='*30)
|
||||
#print('='*30, bits, '='*30)
|
||||
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
|
||||
abserrs = []
|
||||
relerrs = []
|
||||
code = None
|
||||
if method == 'linear':
|
||||
code = F.create_linear_map(True, bits=bits).cuda()
|
||||
code = F.create_linear_map(True, total_bits=bits).cuda()
|
||||
elif method == 'fp8':
|
||||
ebits = math.ceil(bits/2)
|
||||
pbits = bits-ebits-1
|
||||
|
@ -2122,7 +2122,7 @@ def test_few_bit_quant():
|
|||
|
||||
q /= q.abs().max()
|
||||
code, idx = torch.sort(q)
|
||||
print(method, (code==0).sum())
|
||||
#print(method, (code==0).sum())
|
||||
assert code.numel() == 256
|
||||
for i in range(10):
|
||||
|
||||
|
@ -2154,7 +2154,7 @@ def test_few_bit_quant():
|
|||
|
||||
else:
|
||||
torch.testing.assert_allclose(q1, q2)
|
||||
print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
|
||||
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
|
||||
|
||||
|
||||
def test_kbit_quantile_estimation():
|
||||
|
|
Loading…
Reference in New Issue
Block a user