diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 95a7c4f..371f85c 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -168,7 +168,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values = []
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
-    bias = 2**(exponent_bits-1)-1
+    bias = 2**(exponent_bits-1)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -176,10 +176,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value += pval*(2**-(i+1))
             if evalue == 0:
                 # subnormals
-                value = value*2**-(bias-1)
+                value = value*2**-(bias)
             else:
                 # normals
-                value = value*2**-(evalue-bias-2)
+                value = value*2**-(evalue-bias-1)
             values.append(value)
             if signed:
                 values.append(-value)
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index edc595a..221b5f7 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index a623bf1..4746a4a 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -10,6 +10,7 @@ from torch import Tensor, device, dtype, nn
 
 import bitsandbytes as bnb
 from bitsandbytes.optim import GlobalOptimManager
+from bitsandbytes.utils import OutlierTracer, find_outlier_dims
 
 T = TypeVar("T", bound="torch.nn.Module")
 
@@ -133,6 +134,109 @@ class Embedding(torch.nn.Embedding):
 
         return emb
 
 
+class OutlierAwareLinear(nn.Linear):
+    """Linear layer whose weight is quantized once, on the first forward pass.
+
+    The outlier indices are looked up from the global ``OutlierTracer`` using
+    this layer's weight. Subclasses must override ``quantize_weight`` and
+    ``forward_with_outliers``.
+    """
+
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.outlier_dim = None
+        self.is_quantized = False
+
+    def forward_with_outliers(self, x, outlier_idx):
+        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
+
+    def quantize_weight(self, w, outlier_idx):
+        raise NotImplementedError('Please override the `quantize_weight(self, w, outlier_idx)` function')
+
+    def forward(self, x):
+        if self.outlier_dim is None:
+            tracer = OutlierTracer.get_instance()
+            if not tracer.is_initialized():
+                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
+            # Stays None when the tracer has no entry for this weight; the
+            # lookup is then repeated on the next call.
+            self.outlier_dim = tracer.get_outliers(self.weight)
+
+        if not self.is_quantized:
+            w = self.quantize_weight(self.weight, self.outlier_dim)
+            self.weight.data.copy_(w)
+            self.is_quantized = True
+
+        return self.forward_with_outliers(x, self.outlier_dim)
+
+
+class Fake4bitLinear(OutlierAwareLinear):
+    """Linear layer that fake-quantizes: values are quantized, then dequantized.
+
+    Weight columns and activation dimensions flagged as outliers keep their
+    original values; everything else goes through a blockwise
+    quantize/dequantize round trip with blocksize 64.
+    """
+
+    def __init__(self, input_features, output_features, bias=True, codebook=None):
+        super().__init__(input_features, output_features, bias)
+        # Build the default codebook here rather than in the signature, so it
+        # is not computed as a side effect of importing this module.
+        if codebook is None:
+            codebook = bnb.functional.create_fp8_map(True, 3, 0, total_bits=4)
+        self.codebook = codebook
+
+    @staticmethod
+    def _as_index(outlier_idx, device):
+        """Return `outlier_idx` on `device`; `None` becomes an empty index tensor."""
+        if outlier_idx is None:
+            return torch.empty(0, dtype=torch.long, device=device)
+        return outlier_idx.to(device)
+
+    def quantize_weight(self, w, outlier_idx):
+        """Return a fake-quantized copy of `w` with the outlier columns kept as-is."""
+        outlier_idx = self._as_index(outlier_idx, w.device)
+        wdtype = w.dtype
+        # Work on a detached copy: `forward` passes the layer's own weight, and
+        # zeroing columns of it in place would modify the live parameter.
+        w = w.detach().clone()
+        if outlier_idx.numel() > 0:
+            subw = w[:, outlier_idx].clone()
+            # outlier columns are excluded from quantization and restored below
+            w[:, outlier_idx] = 0
+        code = self.codebook.to(w.device)
+        cw, state = bnb.functional.quantize_blockwise(w, code=code, blocksize=64)
+        w = bnb.functional.dequantize_blockwise(cw, state, blocksize=64)
+        w = w.to(wdtype)
+        if outlier_idx.numel() > 0:
+            w[:, outlier_idx] = subw
+        self.is_quantized = True
+        return w
+
+    def forward_with_outliers(self, x, outlier_idx):
+        """Fake-quantize the non-outlier dimensions of `x`, then apply the linear map."""
+        outlier_idx = self._as_index(outlier_idx, x.device)
+        # Dimensions holding any activation with magnitude above 4 are treated
+        # as outliers too (the comparison has to follow abs(), not precede it).
+        dims = (torch.abs(x) > 4).sum(dim=list(range(len(x.shape)-1)))
+        outlier_idx2 = torch.where(dims > 0)[0]
+        outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
+        n = x.shape[-1]
+        idx = torch.arange(n, device=x.device)
+        idx[outlier_idx] = -1
+        inverse_idx = torch.where(idx >= 0)[0]
+        if inverse_idx.numel() > 0:
+            inverse_x = x[..., inverse_idx]
+            code = bnb.functional.create_dynamic_map(True, total_bits=4.0).to(x.device)
+            c, state = bnb.functional.quantize_blockwise(inverse_x, code=code, blocksize=64)
+            inverse_x = bnb.functional.dequantize_blockwise(c, state, blocksize=64)
+            # write into a copy so the caller's tensor is left untouched
+            x = x.clone()
+            x[..., inverse_idx] = inverse_x.to(x.dtype)
+
+        return torch.nn.functional.linear(x, self.weight, self.bias)
+
+
 class Int8Params(torch.nn.Parameter):
     def __new__(
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 1cd90e3..30d9e10 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -1,7 +1,169 @@
 import shlex
 import subprocess
+import torch
 from typing import Tuple
 
 
+def outlier_hook(module, input):
+    """Forward pre-hook that records outlier dimensions for a linear layer.
+
+    Results are stored in the global ``OutlierTracer`` keyed by the hash value
+    of the module's weight. When a weight that is already recorded is seen
+    again, every hook registered by the tracer is removed.
+    """
+    assert isinstance(module, torch.nn.Linear)
+    tracer = OutlierTracer.get_instance()
+    hvalue = tracer.get_hvalue(module.weight)
+    if hvalue not in tracer.hvalue2outlier_idx:
+        outlier_idx = find_outlier_dims(module.weight)
+        tracer.outliers.append(outlier_idx)
+        tracer.hvalues.append(hvalue)
+        if len(tracer.outliers) > 1:
+            # Not the first traced layer: use indices derived from a weight.
+            # NOTE(review): outliers[-1] is the entry appended just above, so
+            # it comes from this module's own weight -- confirm this is intended.
+            if tracer.outliers[-1].numel() > 0:
+                assert tracer.outliers[-1].max() < module.weight.shape[1]
+            tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]
+
+        else:
+            # first layer, we cannot use the weight for outlier detection
+            # we follow a mixed approach:
+            # (1) zscore test of std of hidden dimension
+            # (2) magnitude > 6 test
+            merged = input[0].view(-1, input[0].shape[-1])
+            # (1) zscore test of std of hidden dimension; reduce over rows
+            # (dim 0) so the returned indices address the hidden dimension
+            outlier_idx = find_outlier_dims(merged, reduction_dim=0, zscore=3)
+            # (2) magnitude > 6 test
+            dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1)))
+            outlier_idx2 = torch.where(dims > 0)[0]
+            outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
+            tracer.hvalue2outlier_idx[hvalue] = outlier_idx
+    else:
+        for hook in tracer.hooks:
+            hook.remove()
+
+
+class OutlierTracer(object):
+    """Singleton that maps weight hash values to outlier indices.
+
+    Obtain it with ``OutlierTracer.get_instance()`` and call
+    ``initialize(model)`` before running the model.
+    """
+    _instance = None
+
+    def __init__(self):
+        raise RuntimeError("Call get_instance() instead")
+
+    def initialize(self, model):
+        # drop hooks left over from an earlier initialize() call
+        for hook in getattr(self, 'hooks', []):
+            hook.remove()
+
+        self.last_w = None
+        self.current_outlier_dims = None
+        self.hvalues = []
+        self.outliers = []
+        self.hvalue2outlier_idx = {}
+        self.initialized = True
+        self.hooks = []
+
+        for m in model.modules():
+            if isinstance(m, torch.nn.Linear):
+                self.hooks.append(m.register_forward_pre_hook(outlier_hook))
+
+    def is_initialized(self):
+        return getattr(self, 'initialized', False)
+
+    def get_hvalue(self, weight):
+        # the data pointer of the weight's storage serves as the lookup key
+        return weight.data.storage().data_ptr()
+
+    def get_outliers(self, weight):
+        """Return the recorded outlier indices for `weight`, or None."""
+        if not self.is_initialized():
+            print('Outlier tracer is not initialized...')
+            return None
+        return self.hvalue2outlier_idx.get(self.get_hvalue(weight))
+
+    @classmethod
+    def get_instance(cls):
+        if cls._instance is None:
+            # __new__ bypasses __init__, which always raises
+            cls._instance = cls.__new__(cls)
+        return cls._instance
+
+
+def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
+    """Return the indices of outlier dimensions of `weight`.
+
+    The standard deviation is taken along `reduction_dim`. With `topk` set,
+    the `topk` entries with the largest standard deviation are returned;
+    otherwise those whose z-score among all standard deviations exceeds
+    `zscore`. With `rdm=True`, `topk` random indices below `weight.shape[1]`
+    are returned instead.
+    """
+    if rdm:
+        if topk is None:
+            raise ValueError('`topk` must be given when `rdm=True`')
+        return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
+
+    std = weight.std(reduction_dim)
+
+    if topk is not None:
+        _, idx = torch.topk(std.abs(), k=topk, dim=0)
+    else:
+        stdm = std.mean()
+        stdstd = std.std()
+        zstd = (std-stdm)/stdstd
+        idx = torch.where(zstd > zscore)[0]
+
+    return idx
+
+
+def replace_linear(model, linear_replacement, skip_modules=("lm_head",), copy_weights=False, post_processing_function=None):
+    """
+    Replace linear modules with a new Linear module.
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`Sequence[str]`, *optional*, defaults to `("lm_head",)`):
+            Names of modules not to convert.
+        copy_weights (`bool`):
+            Copy the weights from the old linear module to the new one
+        post_processing_function (`str`):
+            Name of a method that is looked up on each child module (after any
+            replacement) and, if present, called with that module as argument.
+    """
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            old_module = model._modules[name]
+            model._modules[name] = linear_replacement(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+            )
+            if copy_weights:
+                model._modules[name].weight = old_module.weight
+                model._modules[name].bias = old_module.bias
+
+        if post_processing_function is not None:
+            # look the hook up on the module now in place, so a replacement
+            # layer's own method is found rather than the old layer's
+            current = model._modules[name]
+            func = getattr(current, post_processing_function, None)
+            if func is not None:
+                func(current)
+    return model
+
+
 def execute_and_return(command_string: str) -> Tuple[str, str]:
     def _decode(subprocess_err_out_tuple):
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 08b9b44..b32b39c 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -543,7 +543,9 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
     // load code through read-only cache via __ldg
     #pragma unroll NUM_PER_TH
     for(int j = 0; j < NUM_PER_TH; j++)
+    {
       vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
+    }
 
     __syncthreads();
     StoreT(storet).Store(&(out[i]), vals, valid_items);
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 69c200a..70fa4d0 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2109,6 +2109,7 @@ def test_few_bit_quant():
                 ebits = math.ceil(bits/2)
                 pbits = bits-ebits-1
                 code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
+                print(code)
             elif method == 'dynamic':
                 code = F.create_dynamic_map(True, bits-0, bits).cuda()
             elif method == 'quantile':
@@ -2181,7 +2182,9 @@ def test_kbit_quantile_estimation():
 
 def test_bench_dequantization():
     a = torch.rand(1024, 1024, device='cuda').half()
-    qa, SA = F.quantize_blockwise(a)
+    code =F.create_fp8_map(True, 3, 0, 4).cuda()
+    qa, SA = F.quantize_blockwise(a, code=code)
+    print(qa.max())
 
     max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
     #print(max_theoretical_mu)
@@ -2193,3 +2196,4 @@ def test_bench_dequantization():
 
     torch.cuda.synchronize()
     #print((time.time()-t0)/1e6)
+