diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index b38ba1d..969250a 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -9,6 +9,8 @@
 import random
 import torch
 import itertools
 import math
+import scipy.stats
+import numpy as np
 from functools import reduce  # Required in Python 3
 from typing import Tuple
@@ -152,6 +154,70 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True):
         #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
         return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
 
+def custom_map(seed=0, scale=0.01):
+    v = [12, 10, 8, 6, 3, 2, 1]
+    # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45
+    # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48
+
+    # 13B 100 steps:
+    # - 4-bit evo: 86.02
+    # - 4-bit norm: 78.73
+    # - 4-bit FP4:
+    # - 16-bit:
+
+    # interval search on normal distribution
+    #v = [3.090232306167813, 1.4589770349449647, 1.064410327932115, 0.7896806653244509, 0.5646884166925807, 0.3653406435875121, 0.17964844284441311] # 0.999 26.5
+    #v = [2.3263478740408408, 1.4050715603096329, 1.0364333894937898, 0.7721932141886848, 0.5533847195556727, 0.3584587932511938, 0.1763741647808615] # 0.99 24.99
+    #v = [1.6448536269514722, 1.2040469600267016, 0.9208229763683788, 0.6971414348463417, 0.5039653672113453, 0.3280721075316511, 0.16184416680396213] # 0.95 24.53 22.97
+    #v = [1.4050715603096329, 1.0803193408149558, 0.8416212335729143, 0.643345405392917, 0.4676987991145084, 0.3054807880993974, 0.1509692154967774] # 0.92 24.81
+    #v = [1.2815515655446004, 1.0062699858608395, 0.7916386077433746, 0.6084981344998837, 0.4438613119262478, 0.29050677112339396, 0.14372923370582416] # 0.9 24.68
+    #v = [1.8807936081512509, 1.2980047163986055, 0.9769954022693226, 0.7341502955472268, 0.5285136765472481, 0.343225833559403, 0.16910470304375366] # 0.97 25.03
+    #v = [1.7506860712521692, 1.2496468758017434, 0.9485350408266378, 0.7155233557034365, 0.5162006366043174, 0.3356393360829622, 0.16547334454641704] # 0.96 24.85 23.01
+    #v = [1.5547735945968535, 1.1608220210715001, 0.893800631179489, 0.6789921163940618, 0.4918050830048072, 0.3205236191093902, 0.15821711945563585] # 0.94 24.47
+    #v = [1.475791028179171, 1.1196635980209986, 0.8674156943957149, 0.6610637542614526, 0.4797170937629045, 0.31299335020578195, 0.15459215234139795] # 0.93 24.85
+    #v = [1.5981931399228175, 1.1821583959486879, 0.9072289939325966, 0.6880384454306778, 0.49787602226482025, 0.3242955535308664, 0.160030379970179] # 0.945 24.287
+    ##v = [1.6164363711150211, 1.1908453913294612, 0.9126463450304729, 0.6916727602238111, 0.5003095327012462, 0.3258056171348078, 0.1607558311941979] # 0.947 24.293
+    #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207
+    #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30
+    #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293
+    #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88
+
+    # 7B evo start
+    #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06
+    #v = [1.6143079205628337, 1.1888081407660314, 0.8990131955745421, 0.694373759813679, 0.5083033257326773, 0.3452499746844963, 0.1148939728228951]
+    #v = [1.614442766030303, 1.189401918639665, 0.8998038168964273, 0.6953094818279475, 0.5073264599048384, 0.3449003790823619, 0.11428378427205564]
+
+    # 13B evo start
+    #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
+    #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
+    v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]
+
+    # mean evo 7B + 13B
+    #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]
+
+    # theoretically optiomal (0.93333)
+    # v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333
+
+
+
+    if seed > 0:
+        v = np.array(v)
+        np.random.seed(seed)
+        v += np.random.randn(7)*scale
+        print(v.tolist())
+        #v[0] += (np.random.randn(1)*0.001)[0]
+        #v[-1] += (np.random.randn(1)*0.001)[0]
+        #print(v[0], v[-1])
+        v = v.tolist()
+    values = v + [0]*(256-14) + \
+             v[::-1]
+
+    values = torch.Tensor(values)
+    values[0:7] *= -1
+    values = values.sort().values
+    values /= values.max()
+    assert values.numel() == 256
+    return values
 
 def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
     e = exponent_bits
@@ -168,7 +234,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values = []
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
-    bias = 2**(exponent_bits-1)+1
+    bias = 2**(exponent_bits-1)-1
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -176,10 +242,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value += pval*(2**-(i+1))
             if evalue == 0:
                 # subnormals
-                value = value*2**-(bias)
+                value = value*2**-(bias-1)
             else:
                 # normals
-                value = value*2**-(evalue-bias-1)
+                value = value*2**-(evalue-bias-2)
             values.append(value)
         if signed:
             values.append(-value)
@@ -502,7 +568,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         out = torch.zeros_like(A, dtype=torch.uint8)
 
     if A.device.type != 'cpu':
-        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]
        cblocksize = ct.c_int32(blocksize)
         prev_device = pre_call(A.device)
         code = code.to(A.device)
@@ -585,7 +651,7 @@ def dequantize_blockwise(
     if A.device.type != 'cpu':
         device = pre_call(A.device)
         code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64, 32]:
            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, absmax, out])
         if out.dtype == torch.float32:
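
For reference, a minimal usage sketch of the new custom_map code path, based only on the signatures visible above. It assumes quantize_blockwise returns a (quantized, quant_state) pair and that dequantize_blockwise accepts that state; the exact return convention varies between bitsandbytes versions, so treat this as an illustration rather than the documented API.

    import torch
    import bitsandbytes.functional as F

    # Build the 256-entry code from the 7 evolved anchor values; seed=0 keeps the
    # anchors fixed, seed>0 perturbs them with Gaussian noise of the given scale.
    code = F.custom_map(seed=0)

    A = torch.randn(4096, device="cuda")
    quantized, quant_state = F.quantize_blockwise(A, code=code, blocksize=64)
    restored = F.dequantize_blockwise(quantized, quant_state, blocksize=64)
    print((A - restored).abs().mean())
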
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index a2691be..8f33161 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2953,6 +2953,8 @@ template __global__ void kQuantizeBlockwise(float * code, ha
 template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
@@ -2968,8 +2970,6 @@ template __global__ void kQuantizeBlockwise(float * code, ha
 template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-//template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-//template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 
 template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 07ef850..8044c66 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -71,8 +71,8 @@ template void quantizeBlockwise(float * co
     kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 64)
     kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n);
-  //else if(blocksize == 32)
-    //kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n);
+  else if(blocksize == 32 and FP4 == 0)
+    kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n);
 
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
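
The ops.cu branch above is what actually dispatches the new 32-wide kernel, and it only does so for the standard code path (FP4 == 0); the FP4 variant still stops at blocksize 64. A short round trip that would exercise the new path from Python, under the same assumed (quantized, quant_state) return convention as the sketch above:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    # blocksize=32 is newly accepted by the asserts in functional.py and routed to
    # the half/float kernel instantiations added in kernels.cu.
    quantized, quant_state = F.quantize_blockwise(A, blocksize=32)
    restored = F.dequantize_blockwise(quantized, quant_state, blocksize=32)
    print((A.float() - restored.float()).abs().mean())
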
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 54cecca..cd4728e 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -152,7 +152,7 @@ def test_dynamic_quantization():
 
 def test_dynamic_blockwise_quantization():
     #print('')
-    for blocksize in [4096, 2048, 1024, 512, 256, 128, 64]:
+    for blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]:
         diffs = []
         reldiffs = []
         for i in range(100):
@@ -167,8 +167,8 @@ def test_dynamic_blockwise_quantization():
         relerr = sum(reldiffs)/len(reldiffs)
         assert abserr < 0.011
         assert relerr < 0.018
-        #print('randn', blocksize, sum(diffs)/len(diffs))
-        #print('randn', blocksize, sum(reldiffs)/len(reldiffs))
+        print('randn', blocksize, sum(diffs)/len(diffs))
+        print('randn', blocksize, sum(reldiffs)/len(reldiffs))
 
         diffs = []
         for i in range(100):
@@ -184,8 +184,8 @@ def test_dynamic_blockwise_quantization():
         relerr = sum(reldiffs)/len(reldiffs)
         assert abserr < 0.0035
         assert relerr < 0.015
-        #print('rand', blocksize, sum(diffs)/len(diffs))
-        #print('rand', blocksize, sum(reldiffs)/len(reldiffs))
+        print('rand', blocksize, sum(diffs)/len(diffs))
+        print('rand', blocksize, sum(reldiffs)/len(reldiffs))
 
 
 def test_dynamic_blockwise_stochastic_quantization():
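
The re-enabled print calls report the mean absolute and relative round-trip error per blocksize; with pytest they become visible when output capture is disabled, e.g. pytest tests/test_functional.py -k dynamic_blockwise -s. Below is a standalone sketch of the same measurement for the new blocksize; the tensor shape and call pattern are assumptions modeled on the assertions above, not lines shown in the hunks.

    import torch
    import bitsandbytes.functional as F

    diffs, reldiffs = [], []
    for _ in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        quantized, quant_state = F.quantize_blockwise(A1, blocksize=32)
        A2 = F.dequantize_blockwise(quantized, quant_state, blocksize=32)
        diff = (A1 - A2).abs()
        reldiff = diff / (A1.abs() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())

    # Same reporting format as the prints enabled above; a smaller block should give
    # equal or lower error than the 64..4096 settings the test already asserts on.
    print('randn', 32, sum(diffs)/len(diffs))
    print('randn', 32, sum(reldiffs)/len(reldiffs))
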