From 1efb87d89d1c3fe532eb97847c3b48fd1a8e5d83 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Thu, 3 Nov 2022 19:49:50 -0700 Subject: [PATCH 1/8] Added FP8 quantization map. --- bitsandbytes/functional.py | 34 +++++++++++++++++++++++++ tests/test_functional.py | 51 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c104ebd..d7e186f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -6,6 +6,7 @@ import ctypes as ct import operator import random import torch +import itertools from typing import Tuple from torch import Tensor @@ -136,6 +137,39 @@ def create_linear_map(signed=True): return torch.linspace(0.0, 1.0, 256) +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): + e = exponent_bits + p = precision_bits + assert e+p == 7 + # the exponent is biased to 2^(e-1) -1 == 0 + evalues = [] + pvalues = [] + for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)): + evalues.append(2**val) + + + lst = list(itertools.product([0, 1], repeat=precision_bits)) + for bit_pattern in lst: + value = 1 + for i, pval in enumerate(list(bit_pattern)): + value += pval*(2**-(i+1)) + pvalues.append(value) + + assert len(evalues)*len(pvalues) == 128 + values = [] + for ev in evalues: + for pv in pvalues: + values.append(-ev*pv) + values.append(ev*pv) + values.sort() + code = torch.Tensor(values) + code /= code.max() + code[127] = 0 + + return code + + + def create_dynamic_map(signed=True, n=7): """ Creates the dynamic quantiztion map. diff --git a/tests/test_functional.py b/tests/test_functional.py index cf26714..329b270 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large(): assert diffs[-1] < 0.011 # print(sum(diffs)/len(diffs)) # print(sum(reldiffs)/len(reldiffs)) + + + +def test_fp8_quant(): + for e_bits in range(1, 7): + p_bits = 7-e_bits + code = F.create_fp8_map(True, e_bits, p_bits).cuda() + + print(e_bits, p_bits) + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(sum(abserr)/len(abserr)) + print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(sum(abserr)/len(abserr)) + print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(3, sum(abserr)/len(abserr)) + print(3, sum(relerr)/len(relerr)) + From caf1832526e4ad54ae8fe8e947f19ed690f35a40 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 11:47:54 -0800 Subject: [PATCH 2/8] Added k-bit linear quantization. 
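This generalizes create_linear_map so codebooks narrower than 8 bits can be built: 2**bits evenly spaced values are generated and the unused slots of the 256-entry table that the existing kernels index are filled with zeros between the two halves of the code. A usage sketch mirroring the new test (the keyword is `bits` at this point in the series; a later patch renames it to `total_bits`):

    import torch
    from bitsandbytes import functional as F

    code = F.create_linear_map(True, bits=4).cuda()  # 16 distinct levels, zero-padded to 256 entries
    assert code.numel() == 256
    x = torch.randn(1, 32, device='cuda')
    x /= x.abs().max()                               # the codebook covers [-1, 1]
    q, S = F.quantize(x, code=code)                  # nearest-code lookup
    x_hat = F.dequantize(q, S)
    print((x - x_hat).abs().mean())

Because the padding duplicates the value 0, inputs near zero can map to any of the duplicated indices, which is why the accompanying test accepts index mismatches as long as the quantization error is no worse than a brute-force nearest-code search.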
--- bitsandbytes/functional.py | 14 ++++++++--- tests/test_functional.py | 50 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d7e186f..65eccf2 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -130,11 +130,17 @@ class Cusparse_Context(object): return cls._instance -def create_linear_map(signed=True): - if signed: - return torch.linspace(-1.0, 1.0, 256) +def create_linear_map(signed=True, bits=8): + sign = (-1.0 if signed else 0.0) + + values = torch.linspace(sign, 1.0, 2**bits) + gap = 256 - values.numel() + if gap == 0: + return values else: - return torch.linspace(0.0, 1.0, 256) + l = values.numel()//2 + #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) + return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): diff --git a/tests/test_functional.py b/tests/test_functional.py index 329b270..494bf51 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2091,3 +2091,53 @@ def test_fp8_quant(): print(3, sum(abserr)/len(abserr)) print(3, sum(relerr)/len(relerr)) + +def test_few_bit_quant(): + + for bits in range(2, 9): + code = F.create_linear_map(True, bits=bits).cuda() + assert code.numel() == 256 + print(bits) + for i in range(100): + + values = torch.randn(1, 24, device='cuda') + values /= values.abs().max() + #values[values.abs() < 1e-6] += 1e-5 + + q1 = [] + v1 = [] + for v in values[0]: + idx = torch.abs(v-code).argmin() + q1.append(idx.item()) + v1.append(code[idx].item()) + + q1 = torch.Tensor(q1).cuda() + v1 = torch.Tensor(v1).cuda() + + q2, S2 = F.quantize(values, code=code) + v2 = F.dequantize(q2, S2) + + idx = torch.isclose(q1.int(), q2.int()) + if idx.sum(): + # some weird cases + err1 = torch.abs(v1-values).mean() + err2 = torch.abs(v2-values).mean() + assert err2 <= err1 + + else: + torch.testing.assert_allclose(q1, q2) + + #print(e_bits, p_bits) + #abserr = [] + #relerr = [] + #for i in range(100): + # A1 = torch.randn(1024, 1024, device="cuda") + # C, SC = F.quantize_blockwise(A1, code=code) + # A2 = F.dequantize_blockwise(C, SC) + # diff = torch.abs(A1 - A2) + # reldiff = diff/torch.abs(A1+1e-8) + # abserr.append(diff.mean().item()) + # relerr.append(reldiff.mean().item()) + # #assert diff < 0.0075 + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) From 98cbc4bc4f15f5c094cd8575ddb0380a19516099 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 11:59:37 -0800 Subject: [PATCH 3/8] Added k-bit fp8 map. 
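create_fp8_map now takes a total_bits argument so the same construction covers codes narrower than 8 bits: one bit is reserved for the sign when signed=True, the remaining bits are split into exponent and mantissa, every exponent/mantissa combination is enumerated, mirrored for the sign, zero-padded to 256 entries, sorted, and normalized so the largest magnitude is 1. A sketch of how the test picks the split for a k-bit code (the ceil/floor split is the test's choice, not something the function requires):

    import math
    import torch
    from bitsandbytes import functional as F

    bits = 5
    ebits = math.ceil(bits / 2)     # exponent bits
    pbits = bits - ebits - 1        # mantissa bits; the remaining bit is the sign
    code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
    assert code.numel() == 256      # k-bit codes are padded with zeros to 256 slots
    # before normalization, each positive entry is 2**e * (1 + m1/2 + m2/4 + ...),
    # with e running over the biased exponent range and m_i over the mantissa bits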
--- bitsandbytes/functional.py | 16 +++++--- tests/test_functional.py | 76 ++++++++++++++++++-------------------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 65eccf2..ff48b7f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8): return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) -def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits - assert e+p == 7 + has_sign = 1 if signed else 0 + assert e+p == total_bits-has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)): + for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): evalues.append(2**val) @@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): value += pval*(2**-(i+1)) pvalues.append(value) - assert len(evalues)*len(pvalues) == 128 + assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign) values = [] for ev in evalues: for pv in pvalues: - values.append(-ev*pv) + if signed: + values.append(-ev*pv) values.append(ev*pv) + if total_bits < 8: + gap = 256 - len(values) + for i in range(gap): + values.append(0) values.sort() code = torch.Tensor(values) code /= code.max() diff --git a/tests/test_functional.py b/tests/test_functional.py index 494bf51..bd4dafe 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -11,7 +11,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F torch.set_printoptions( - precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 + precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 ) k = 20 @@ -2095,49 +2095,43 @@ def test_fp8_quant(): def test_few_bit_quant(): for bits in range(2, 9): - code = F.create_linear_map(True, bits=bits).cuda() - assert code.numel() == 256 - print(bits) - for i in range(100): + for method in ['linear', 'fp8']: + code = None + if method == 'linear': + code = F.create_linear_map(True, bits=bits).cuda() + elif method == 'fp8': + ebits = math.ceil(bits/2) + pbits = bits-ebits-1 + code = F.create_fp8_map(True, ebits, pbits, bits).cuda() + print(ebits, pbits, bits) + print(code) + assert code.numel() == 256 + print(bits) + for i in range(10): - values = torch.randn(1, 24, device='cuda') - values /= values.abs().max() - #values[values.abs() < 1e-6] += 1e-5 + values = torch.randn(1, 32, device='cuda') + values /= values.abs().max() + #values[values.abs() < 1e-6] += 1e-5 - q1 = [] - v1 = [] - for v in values[0]: - idx = torch.abs(v-code).argmin() - q1.append(idx.item()) - v1.append(code[idx].item()) + q1 = [] + v1 = [] + for v in values[0]: + idx = torch.abs(v-code).argmin() + q1.append(idx.item()) + v1.append(code[idx].item()) - q1 = torch.Tensor(q1).cuda() - v1 = torch.Tensor(v1).cuda() + q1 = torch.Tensor(q1).cuda() + v1 = torch.Tensor(v1).cuda() - q2, S2 = F.quantize(values, code=code) - v2 = F.dequantize(q2, S2) + q2, S2 = F.quantize(values, code=code) + v2 = F.dequantize(q2, S2) - idx = torch.isclose(q1.int(), q2.int()) - if idx.sum(): - # some weird cases - err1 = torch.abs(v1-values).mean() - err2 = torch.abs(v2-values).mean() - assert err2 <= err1 + idx = torch.isclose(q1.int(), q2.int()) + if idx.sum(): + # 
some weird cases + err1 = torch.abs(v1-values).mean() + err2 = torch.abs(v2-values).mean() + assert err2 <= err1 - else: - torch.testing.assert_allclose(q1, q2) - - #print(e_bits, p_bits) - #abserr = [] - #relerr = [] - #for i in range(100): - # A1 = torch.randn(1024, 1024, device="cuda") - # C, SC = F.quantize_blockwise(A1, code=code) - # A2 = F.dequantize_blockwise(C, SC) - # diff = torch.abs(A1 - A2) - # reldiff = diff/torch.abs(A1+1e-8) - # abserr.append(diff.mean().item()) - # relerr.append(reldiff.mean().item()) - # #assert diff < 0.0075 - #print(sum(abserr)/len(abserr)) - #print(sum(relerr)/len(relerr)) + else: + torch.testing.assert_allclose(q1, q2) From 2f2063bac212bcd6a515a88a12a9530b5730dabe Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 13:05:25 -0800 Subject: [PATCH 4/8] Added k<256 quantile estimate. --- bitsandbytes/functional.py | 61 +++++++++++++++++++++++--------------- tests/test_functional.py | 43 +++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 30 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ff48b7f..076414d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) -def create_dynamic_map(signed=True, n=7): +def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. @@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7): # these are additional items that come from the case # where all the exponent bits are zero and no # indicator bit is present - additional_items = 2 ** (7 - n) - 1 + non_sign_bits = total_bits - (1 if signed else 0) + additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 if not signed: additional_items = 2 * additional_items - for i in range(n): - fraction_items = ( - 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 - ) + for i in range(max_exponent_bits): + fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(n - 1) + i)) * means).tolist() + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() if signed: - data += (-(10 ** (-(n - 1) + i)) * means).tolist() + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items + 1) - means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(n - 1) + i)) * means).tolist() - if signed: - data += (-(10 ** (-(n - 1) + i)) * means).tolist() + if additional_items > 0: + boundaries = torch.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() data.append(0) data.append(1.0) + + gap = 256 - len(data) + for i in range(gap): + data.append(0) + data.sort() return Tensor(data) @@ -371,9 +375,7 @@ def nvidia_transform( return out, new_state -def estimate_quantiles( - A: Tensor, out: Tensor = None, offset: float = 1 / 512 -) -> Tensor: +def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: ''' Estimates 256 equidistant quantiles on the input tensor eCDF. 
@@ -393,25 +395,36 @@ def estimate_quantiles( out : torch.Tensor Tensor with the 256 estimated quantiles. offset : float - The offset for the first and last quantile from 0 and 1. Default: 1/512 + The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles) + num_quantiles : int + The number of equally spaced quantiles. Returns ------- torch.Tensor: The 256 quantiles in float32 datatype. ''' + if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') + if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") + if num_quantiles < 256 and offset == 1/(512): + # override default arguments + offset = 1/(2*num_quantiles) + if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) + device = pre_call(A.device) if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32( - get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) - ) + lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16( - get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) - ) + lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) else: raise NotImplementedError(f"Not supported data type {A.dtype}") + post_call(device) + + if num_quantiles < 256: + idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) + out = out[idx] + return out diff --git a/tests/test_functional.py b/tests/test_functional.py index bd4dafe..99885da 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -6,9 +6,11 @@ from itertools import product import einops import pytest import torch +import numpy as np import bitsandbytes as bnb from bitsandbytes import functional as F +from scipy.stats import norm torch.set_printoptions( precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 @@ -2094,8 +2096,12 @@ def test_fp8_quant(): def test_few_bit_quant(): + print('') for bits in range(2, 9): - for method in ['linear', 'fp8']: + print('='*30, bits, '='*30) + for method in ['linear', 'fp8', 'dynamic', 'quantile']: + abserrs = [] + relerrs = [] code = None if method == 'linear': code = F.create_linear_map(True, bits=bits).cuda() @@ -2103,10 +2109,21 @@ def test_few_bit_quant(): ebits = math.ceil(bits/2) pbits = bits-ebits-1 code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - print(ebits, pbits, bits) - print(code) + elif method == 'dynamic': + code = F.create_dynamic_map(True, bits-0, bits).cuda() + elif method == 'quantile': + values = torch.randn(2048, 2048, device='cuda') + q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits) + gap = 256-q.numel() + q = q.tolist() + for i in range(gap): + q.append(0) + q = torch.Tensor(q).cuda() + + q /= q.abs().max() + code, idx = torch.sort(q) + print(method, (code==0).sum()) assert code.numel() == 256 - print(bits) for i in range(10): values = torch.randn(1, 32, device='cuda') @@ -2127,11 +2144,25 @@ def test_few_bit_quant(): v2 = F.dequantize(q2, S2) idx = torch.isclose(q1.int(), q2.int()) + err2 = torch.abs(v2-values) + abserrs.append(err2.mean().item()) + relerrs.append((err2/(1e-10+values).abs()).mean().item()) if idx.sum(): # some weird cases err1 = torch.abs(v1-values).mean() - err2 = 
torch.abs(v2-values).mean() - assert err2 <= err1 + assert err2.mean() <= err1 else: torch.testing.assert_allclose(q1, q2) + print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + + +def test_kbit_quantile_estimation(): + for i in range(100): + data = torch.randn(1024, 1024, device='cuda') + for bits in range(2, 9): + p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) + val1 = torch.Tensor(norm.ppf(p)).cuda() + val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) + err = torch.abs(val1-val2).mean() + assert err < 0.035 From 6bc2b992be0bb7511ea881f8ebbbd2ba7f1b5109 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 16:27:48 -0800 Subject: [PATCH 5/8] Added blocksizes 2048, 1024, and 512 to blockwise quant. --- bitsandbytes/cextension.py | 11 ++++- bitsandbytes/functional.py | 20 ++++----- csrc/kernels.cu | 22 +++++++--- csrc/ops.cu | 33 ++++++++++---- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 12 ++--- tests/test_functional.py | 90 +++++++++++++++++++------------------- 7 files changed, 112 insertions(+), 78 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 8125202..ead8502 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -52,8 +52,13 @@ class CUDASetup(object): self.add_log_entry('python setup.py install') def initialize(self): - self.cuda_setup_log = [] + self.has_printed = False self.lib = None + self.run_cuda_setup() + + def run_cuda_setup(self): + self.initialized = True + self.cuda_setup_log = [] from .cuda_setup.main import evaluate_cuda_setup binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup() @@ -89,7 +94,9 @@ class CUDASetup(object): else: self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...") self.lib = ct.cdll.LoadLibrary(binary_path) - except: + print(self.lib) + except Exception as ex: + self.add_log_entry(str(ex)) self.print_log_stack() def add_log_entry(self, msg, is_warning=False): diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 076414d..49d4db1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -130,10 +130,10 @@ class Cusparse_Context(object): return cls._instance -def create_linear_map(signed=True, bits=8): +def create_linear_map(signed=True, total_bits=8): sign = (-1.0 if signed else 0.0) - values = torch.linspace(sign, 1.0, 2**bits) + values = torch.linspace(sign, 1.0, 2**total_bits) gap = 256 - values.numel() if gap == 0: return values @@ -457,6 +457,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra The quantization state to undo the quantization. 
""" + if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -474,8 +475,11 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra out = torch.zeros_like(A, dtype=torch.uint8) if A.device.type != 'cpu': + assert blocksize in [4096, 2048, 1024, 512] is_on_gpu([code, A, absmax, out, rand]) + cblocksize = ct.c_int32(blocksize) if rand is not None: + assert blocksize==4096 assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) if A.dtype == torch.float32: @@ -483,18 +487,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra elif A.dtype == torch.float16: lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: # cpu assert rand is None diff --git a/csrc/kernels.cu b/csrc/kernels.cu index f01b4e1..9d9653c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c } template -__launch_bounds__(TH, 4) +//__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { const int n_full = gridDim.x * BLOCK_SIZE; int valid_items = 0; const int base_idx = (blockIdx.x * BLOCK_SIZE); - T vals[NUM]; - float rand_vals[NUM]; - unsigned char qvals[NUM]; + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; @@ -517,8 +517,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c int valid_items = 0; const int base_idx = (blockIdx.x * BLOCK_SIZE); - T vals[NUM]; - unsigned char qvals[NUM]; + T vals[NUM_PER_TH]; + unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; @@ -2791,11 +2791,21 @@ template __global__ void kQuantizeBlockwise(float * code, half template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned 
char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index e49c94b..b121fc2 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,11 +50,23 @@ void dequantize(float *code, unsigned char *A, float *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? 
num_blocks : num_blocks + 1; + if(STOCHASTIC == 1) + assert(blocksize == 4096); + + if(blocksize == 4096) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -66,6 +78,11 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo kDequantizeBlockwise<<>>(code, A, absmax, out, n); else if(blocksize == 2048) kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 1024) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 512) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -659,10 +676,10 @@ template void transformRowToFormat(char * A, char *out, int rows, template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index acfdb06..66e3843 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -128,7 +128,7 @@ template void estimateQuantiles(T *A, float *code, float offset, in void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template void optimizer32bit(T* g, T* p, diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 58e26a9..5bac30e 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) 
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } -void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, n); } -void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, n); } -void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, n); } +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } @@ -140,8 +140,8 @@ extern "C" void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } - void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); } - void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); } + void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, 
absmax, out, rand, rand_offset, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 99885da..b525dff 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -151,30 +151,41 @@ def test_dynamic_quantization(): def test_dynamic_blockwise_quantization(): - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - assert diffs[-1] < 0.011 - # print(sum(diffs)/len(diffs)) - # print(sum(reldiffs)/len(reldiffs)) + #print('') + for blocksize in [4096, 2048, 1024, 512]: + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.011 + assert relerr < 0.018 + #print('randn', blocksize, sum(diffs)/len(diffs)) + #print('randn', blocksize, sum(reldiffs)/len(reldiffs)) - diffs = [] - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1 - A2).mean().item() - assert diff < 0.0033 - diffs.append(diff) - torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - # print(sum(diffs)/len(diffs)) + diffs = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.0035 + assert relerr < 0.015 + #print('rand', blocksize, sum(diffs)/len(diffs)) + #print('rand', blocksize, sum(reldiffs)/len(reldiffs)) def test_dynamic_blockwise_stochastic_quantization(): @@ -1618,17 +1629,6 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): # print(time.time() - t0) -def test_layout(): - a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16) - a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte() - a2, s2 = F.transform(a1, "col_turing") - print(a2.shape) - - print(a1.flatten()[8 * 64 : 8 * 64 + 32]) - for i in range(4): - print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0) - - def test_coo2csr(): threshold = 1 A = torch.randn(128, 128).half().cuda() @@ -2062,8 +2062,8 @@ def test_fp8_quant(): abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) #assert diff < 0.0075 - print(sum(abserr)/len(abserr)) - print(sum(relerr)/len(relerr)) + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -2076,8 +2076,8 @@ def test_fp8_quant(): abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) #assert diff < 0.0075 - print(sum(abserr)/len(abserr)) - print(sum(relerr)/len(relerr)) + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -2090,21 +2090,21 @@ def test_fp8_quant(): abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) #assert diff < 0.0075 - print(3, sum(abserr)/len(abserr)) - 
print(3, sum(relerr)/len(relerr)) + #print(3, sum(abserr)/len(abserr)) + #print(3, sum(relerr)/len(relerr)) def test_few_bit_quant(): - print('') + #print('') for bits in range(2, 9): - print('='*30, bits, '='*30) + #print('='*30, bits, '='*30) for method in ['linear', 'fp8', 'dynamic', 'quantile']: abserrs = [] relerrs = [] code = None if method == 'linear': - code = F.create_linear_map(True, bits=bits).cuda() + code = F.create_linear_map(True, total_bits=bits).cuda() elif method == 'fp8': ebits = math.ceil(bits/2) pbits = bits-ebits-1 @@ -2122,7 +2122,7 @@ def test_few_bit_quant(): q /= q.abs().max() code, idx = torch.sort(q) - print(method, (code==0).sum()) + #print(method, (code==0).sum()) assert code.numel() == 256 for i in range(10): @@ -2154,7 +2154,7 @@ def test_few_bit_quant(): else: torch.testing.assert_allclose(q1, q2) - print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) def test_kbit_quantile_estimation(): From e0e697b150ba830d19a2f5fbeaf22f1349eddbe3 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 16:36:31 -0800 Subject: [PATCH 6/8] Fixed blockwise test and logic. --- bitsandbytes/functional.py | 10 ++++------ tests/test_functional.py | 10 +++++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 49d4db1..aef6971 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -466,7 +466,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra if absmax is None: n = A.numel() - blocksize = (blocksize if A.device.type == 'cpu' else 4096) + blocksize = (blocksize if A.device.type == 'cuda' else 4096) blocks = n // blocksize blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) @@ -550,17 +550,15 @@ def dequantize_blockwise( if A.device.type != 'cpu': - if blocksize not in [2048, 4096]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]") + if blocksize not in [2048, 4096, 1024, 512]: + raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512]") is_on_gpu([A, out]) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) diff --git a/tests/test_functional.py b/tests/test_functional.py index b525dff..4642b16 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -157,8 +157,8 @@ def test_dynamic_blockwise_quantization(): reldiffs = [] for i in range(100): A1 = torch.randn(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) + C, S = F.quantize_blockwise(A1, blocksize=blocksize) + A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) diff = torch.abs(A1 - A2) reldiff = diff / torch.abs(A1 + 1e-8) diffs.append(diff.mean().item()) @@ -173,13 +173,13 @@ def test_dynamic_blockwise_quantization(): diffs = [] for i in range(100): A1 = torch.rand(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) + C, S = F.quantize_blockwise(A1, blocksize=blocksize) + A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) diff = torch.abs(A1 - A2) reldiff = diff / torch.abs(A1 + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) abserr = sum(diffs)/len(diffs) relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.0035 From 62a333ac40f157e69c4bb86f30ac06b41ca4ff34 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 17:17:51 -0800 Subject: [PATCH 7/8] Added pre/post calls do quantize_blockwise. 
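quantize_blockwise and dequantize_blockwise now wrap their kernel launches in pre_call/post_call so the C kernels run with A's device active and the previously active device is restored afterwards; this matters when the input lives on a GPU other than the current one. A minimal sketch of the device-guard idea, assuming standard torch.cuda device switching (this illustrates the pattern only, it is not the library's pre_call/post_call implementation):

    import torch

    def launch_on_tensor_device(A, launch):
        # make A's device current for the kernel launch, then restore the old device
        prev = torch.cuda.current_device()
        torch.cuda.set_device(A.device)
        try:
            return launch(A)
        finally:
            torch.cuda.set_device(prev)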
--- bitsandbytes/functional.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index aef6971..6278db9 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -458,6 +458,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra """ + prev_device = pre_call(A.device) if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -479,6 +480,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra is_on_gpu([code, A, absmax, out, rand]) cblocksize = ct.c_int32(blocksize) if rand is not None: + is_on_gpu([code, A, out, absmax, rand]) assert blocksize==4096 assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) @@ -489,6 +491,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: + is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.float16: @@ -499,6 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra # cpu assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + post_call(A.device) return out, (absmax, code) @@ -537,6 +541,7 @@ def dequantize_blockwise( Dequantized tensor (default: float32) """ assert quant_state is not None or absmax is not None + device = pre_call(A.device) if code is None and quant_state is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -561,6 +566,7 @@ def dequantize_blockwise( raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + post_call(A.device) return out From 08fa2e7b01dda8959a930295de9829516f8c77bc Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 7 Nov 2022 18:06:18 -0800 Subject: [PATCH 8/8] Fixed bug in cpu quant; faster GPU dequant. 
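Two fixes: the CPU branch of quantize_blockwise/dequantize_blockwise now moves the codebook to the CPU instead of forcing it onto A.device up front, so quantizing CPU tensors works again, and the GPU dequantization kernel looks the codebook up through the read-only cache (__ldg) instead of staging it in shared memory. A quick round-trip to exercise the repaired CPU path (ordinary usage of the existing API, nothing new assumed):

    import torch
    from bitsandbytes import functional as F

    A = torch.randn(1024, 1024)      # CPU tensor, default dynamic codebook
    C, S = F.quantize_blockwise(A)   # S holds (absmax, code)
    A2 = F.dequantize_blockwise(C, S)
    print((A - A2).abs().mean())     # should stay below ~0.011, the bound used by the CPU test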
--- bitsandbytes/cextension.py | 1 - bitsandbytes/functional.py | 22 ++++++++++++---------- csrc/kernels.cu | 28 +++++++++++++++------------- csrc/kernels.cuh | 2 +- tests/test_functional.py | 16 ++++++++++++++++ 5 files changed, 44 insertions(+), 25 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index ead8502..264e899 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -94,7 +94,6 @@ class CUDASetup(object): else: self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...") self.lib = ct.cdll.LoadLibrary(binary_path) - print(self.lib) except Exception as ex: self.add_log_entry(str(ex)) self.print_log_stack() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6278db9..fffbecf 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -458,16 +458,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra """ - prev_device = pre_call(A.device) if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - code = code.to(A.device) if absmax is None: n = A.numel() - blocksize = (blocksize if A.device.type == 'cuda' else 4096) blocks = n // blocksize blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) @@ -477,8 +474,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra if A.device.type != 'cpu': assert blocksize in [4096, 2048, 1024, 512] - is_on_gpu([code, A, absmax, out, rand]) cblocksize = ct.c_int32(blocksize) + prev_device = pre_call(A.device) + code = code.to(A.device) if rand is not None: is_on_gpu([code, A, out, absmax, rand]) assert blocksize==4096 @@ -498,11 +496,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) else: # cpu + code = code.cpu() assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - post_call(A.device) return out, (absmax, code) @@ -541,32 +540,35 @@ def dequantize_blockwise( Dequantized tensor (default: float32) """ assert quant_state is not None or absmax is not None - device = pre_call(A.device) if code is None and quant_state is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - code = code.to(A.device) if out is None: out = torch.zeros_like(A, dtype=torch.float32) if quant_state is None: quant_state = (absmax, code) + else: + absmax, code = quant_state if A.device.type != 'cpu': + device = pre_call(A.device) + code = code.to(A.device) if blocksize not in [2048, 4096, 1024, 512]: raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512]") is_on_gpu([A, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) else: + code = code.cpu() lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - post_call(A.device) return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 9d9653c..4c750d1 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -510,7 +510,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float } template -__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n) +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n) { const int n_full = gridDim.x * BLOCK_SIZE; @@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - __shared__ float smem_code[256]; + //__shared__ float smem_code[256]; + //float local_code[16]; - if(threadIdx.x < 256) - smem_code[threadIdx.x] = code[threadIdx.x]; + //if(threadIdx.x < 256) + //smem_code[threadIdx.x] = code[threadIdx.x]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { @@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128); + // load code through read-only cache via __ldg #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = smem_code[qvals[j]]*local_abs_max; + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; __syncthreads(); StoreT(storet).Store(&(out[i]), vals, valid_items); @@ -2798,14 +2800,14 @@ template __global__ void kQuantizeBlockwise(float * code, flo template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, 
unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index bdf61b2..cca983b 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -15,7 +15,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n); template __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, diff --git a/tests/test_functional.py b/tests/test_functional.py index 4642b16..d36dfc1 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2166,3 +2166,19 @@ def test_kbit_quantile_estimation(): val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) err = torch.abs(val1-val2).mean() assert err < 0.035 + + +def test_bench_dequantization(): + a = torch.rand(1024, 1024, device='cuda').half() + qa, SA = F.quantize_blockwise(a) + + max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 + #print(max_theoretical_mu) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + F.dequantize_blockwise(qa, SA, blocksize=2048) + torch.cuda.synchronize() + #print((time.time()-t0)/1e6) +
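A note on the bound in test_bench_dequantization: assuming the constant 672 is the development GPU's memory bandwidth in GB/s (a property of the test machine, not stated in the patch), max_theoretical_mu counts two bytes for each of the 1024*1024 elements and converts that to microseconds per call:

    # the test's arithmetic spelled out; 672 GB/s is an assumed bandwidth figure
    elements = 1024 * 1024             # size of the benchmarked tensor
    payload  = elements * 2            # two bytes per element (the fp16 payload)
    seconds  = payload / 1024**3 / 672
    print(seconds * 1e6)               # ~2.9 microseconds per call

The loop then times 100 dequantize_blockwise calls at blocksize 2048, so the measured time per call can be compared against this memory-bandwidth bound.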