Added outlier detector and fake quantization layer.
This commit is contained in:
parent 1341fb44ad
commit c9f505064e
@@ -168,7 +168,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values = []
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
-    bias = 2**(exponent_bits-1)-1
+    bias = 2**(exponent_bits-1)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -176,10 +176,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value += pval*(2**-(i+1))
             if evalue == 0:
                 # subnormals
-                value = value*2**-(bias-1)
+                value = value*2**-(bias)
             else:
                 # normals
-                value = value*2**-(evalue-bias-2)
+                value = value*2**-(evalue-bias-1)
             values.append(value)
             if signed:
                 values.append(-value)
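For orientation, a minimal worked example (not part of the commit) of how the new bias and exponent scaling decode a single FP8 value; the field values below are hypothetical and the logic simply mirrors the loop body above.

# Sketch, not part of the diff: decode one FP8 value with the updated bias.
exponent_bits, precision_bits = 5, 2
bias = 2**(exponent_bits-1)            # 16 with the new formula (was 15 with the old "-1")
evalue, bit_pattern = 17, (1, 0)       # hypothetical exponent field and mantissa bits

value = 1.0                            # implicit leading 1, since evalue != 0 (normal number)
for i, pval in enumerate(bit_pattern):
    value += pval*(2**-(i+1))          # 1.5
value = value*2**-(evalue-bias-1)      # 1.5 * 2**-0 = 1.5 with the new scaling
print(value)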
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear
@@ -10,6 +10,7 @@ from torch import Tensor, device, dtype, nn

 import bitsandbytes as bnb
 from bitsandbytes.optim import GlobalOptimManager
+from bitsandbytes.utils import OutlierTracer, find_outlier_dims

 T = TypeVar("T", bound="torch.nn.Module")

@@ -133,6 +134,83 @@ class Embedding(torch.nn.Embedding):

         return emb

+
+class OutlierAwareLinear(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.outlier_dim = None
+        self.is_quantized = False
+
+    def forward_with_outliers(self, x, outlier_idx):
+        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
+
+    def quantize_weight(self, w, outlier_idx):
+        raise NotImplementedError('Please override the `quantize_weight(self, w, outlier_idx)` function')
+
+    def forward(self, x):
+        if self.outlier_dim is None:
+            tracer = OutlierTracer.get_instance()
+            if not tracer.is_initialized():
+                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
+            outlier_idx = tracer.get_outliers(self.weight)
+            #print(outlier_idx, tracer.get_hvalue(self.weight))
+            self.outlier_dim = outlier_idx
+
+        if not self.is_quantized:
+            w = self.quantize_weight(self.weight, self.outlier_dim)
+            self.weight.data.copy_(w)
+            self.is_quantized = True
+
+        return self.forward_with_outliers(x, self.outlier_dim)
+
+
+class Fake4bitLinear(OutlierAwareLinear):
+    def __init__(self, input_features, output_features, bias=True, codebook=bnb.functional.create_fp8_map(True, 3, 0, total_bits=4)):
+        super().__init__(input_features, output_features, bias)
+        self.codebook = codebook
+
+    def quantize_weight(self, w, outlier_idx):
+        if outlier_idx.numel() > 0:
+            # keep the outlier columns in full precision
+            subw = w[:, outlier_idx].clone()
+            w[:, outlier_idx] = 0
+        wdtype = w.dtype
+        code = self.codebook.to(w.device)
+        cw, state = bnb.functional.quantize_blockwise(w, code=code, blocksize=64)
+        w = bnb.functional.dequantize_blockwise(cw, state, blocksize=64)
+        w = w.to(wdtype)
+        if outlier_idx.numel() > 0:
+            w[:, outlier_idx] = subw
+        self.is_quantized = True
+        return w
+
+    def forward_with_outliers(self, x, outlier_idx):
+        dims = (torch.abs(x) > 4).sum(dim=list(range(len(x.shape)-1)))
+        outlier_idx2 = torch.where(dims > 0)[0]
+        outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
+        n = x.shape[-1]
+        idx = torch.arange(n, device=x.device)
+        idx[outlier_idx] = -1
+        inverse_idx = torch.where(idx >= 0)[0]
+        if outlier_idx.numel() > 0:
+            subx = x[..., outlier_idx].clone()
+            #print(1, subx, 1)
+            #x[..., outlier_idx] = 0
+        inverse_x = x[..., inverse_idx]
+        xdtype = x.dtype
+        #code = bnb.functional.create_fp8_map(True, 4-3, 2, 4).to(x.device)
+        #code = bnb.functional.create_quantile_map(x, 4).to(x.device)
+        code = bnb.functional.create_dynamic_map(True, total_bits=4.0).to(x.device)
+        c, state = bnb.functional.quantize_blockwise(inverse_x, code=code, blocksize=64)
+        inverse_x = bnb.functional.dequantize_blockwise(c, state, blocksize=64)
+        #c, state = bnb.functional.quantize_blockwise(x, code=code, blocksize=64)
+        #x = bnb.functional.dequantize_blockwise(c, state, blocksize=64)
+        x = x.to(xdtype)
+        x[..., inverse_idx] = inverse_x.to(x.dtype)
+        #if outlier_idx.numel() > 0:
+        #    x[..., outlier_idx] = subx
+
+        return torch.nn.functional.linear(x, self.weight, self.bias)
+
+
 class Int8Params(torch.nn.Parameter):
     def __new__(
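A rough usage sketch (not from the commit) of how these new layers might be wired together with the tracer and the replace_linear helper added in the utils hunk below. The toy model and shapes are made up, and it assumes a CUDA device with a working bitsandbytes build.

import torch
from bitsandbytes.nn import Fake4bitLinear
from bitsandbytes.utils import OutlierTracer, replace_linear

# toy stand-in for a real transformer; shapes are arbitrary
model = torch.nn.Sequential(
    torch.nn.Linear(64, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
).cuda()

tracer = OutlierTracer.get_instance()
tracer.initialize(model)              # registers forward pre-hooks on every nn.Linear
model(torch.randn(4, 64).cuda())      # one forward pass records outlier dims per weight

# swap nn.Linear for the fake-quantized layer, reusing the traced weight tensors
model = replace_linear(model, Fake4bitLinear, copy_weights=True)
out = model(torch.randn(4, 64).cuda())   # weights are quantized lazily on the first forward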
@@ -1,7 +1,143 @@
 import shlex
 import subprocess
+import torch
 from typing import Tuple

+
+def outlier_hook(module, input):
+    assert isinstance(module, torch.nn.Linear)
+    tracer = OutlierTracer.get_instance()
+    hvalue = tracer.get_hvalue(module.weight)
+    if hvalue not in tracer.hvalue2outlier_idx:
+        outlier_idx = find_outlier_dims(module.weight)
+        tracer.outliers.append(outlier_idx)
+        tracer.hvalues.append(hvalue)
+        if len(tracer.outliers) > 1:
+            # assign the current layer the outlier idx found from the weight
+            # of the previous linear layer
+            if tracer.outliers[-1].numel() > 0:
+                assert tracer.outliers[-1].max() < module.weight.shape[1]
+            tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]
+        else:
+            # first layer, we cannot use the weight for outlier detection
+            # we follow a mixed approach:
+            # (1) zscore test of std of hidden dimension
+            # (2) magnitude > 6 test
+            merged = input[0].view(-1, input[0].shape[-1])
+            # (1) zscore test of std of hidden dimension
+            outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
+            # (2) magnitude > 6 test
+            dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape)-1)))
+            outlier_idx2 = torch.where(dims > 0)[0]
+            outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
+            tracer.hvalue2outlier_idx[hvalue] = outlier_idx
+    else:
+        for hook in tracer.hooks:
+            hook.remove()
+
+
+class OutlierTracer(object):
+    _instance = None
+
+    def __init__(self):
+        raise RuntimeError("Call get_instance() instead")
+
+    def initialize(self, model):
+        self.last_w = None
+        self.current_outlier_dims = None
+        self.hvalues = []
+        self.outliers = []
+        self.hvalue2outlier_idx = {}
+        self.initialized = True
+        self.hooks = []
+
+        for n, m in model.named_modules():
+            if isinstance(m, torch.nn.Linear):
+                self.hooks.append(m.register_forward_pre_hook(outlier_hook))
+
+    def is_initialized(self):
+        return getattr(self, 'initialized', False)
+
+    def get_hvalue(self, weight):
+        return weight.data.storage().data_ptr()
+
+    def get_outliers(self, weight):
+        if not self.is_initialized():
+            print('Outlier tracer is not initialized...')
+            return None
+        hvalue = self.get_hvalue(weight)
+        if hvalue in self.hvalue2outlier_idx:
+            return self.hvalue2outlier_idx[hvalue]
+        else:
+            return None
+
+    @classmethod
+    def get_instance(cls):
+        if cls._instance is None:
+            cls._instance = cls.__new__(cls)
+        return cls._instance
+
+
+def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
+    if rdm:
+        return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
+
+    m = weight.mean(reduction_dim)
+    mm = m.mean()
+    mstd = m.std()
+    zm = (m-mm)/mstd
+
+    std = weight.std(reduction_dim)
+    stdm = std.mean()
+    stdstd = std.std()
+
+    zstd = (std-stdm)/stdstd
+
+    if topk is not None:
+        val, idx = torch.topk(std.abs(), k=topk, dim=0)
+    else:
+        idx = torch.where(zstd > zscore)[0]
+
+    return idx
+
+
+def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
+    """
+    Replace linear modules with a new Linear module.
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
+            List of module names not to convert.
+        copy_weights (`bool`):
+            Copy the weights from the old linear module to the new one.
+        post_processing_function (`str`):
+            The name of a function of the replacement linear class that is called
+            after processing.
+    """
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            old_module = model._modules[name]
+            model._modules[name] = linear_replacement(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+            )
+            if copy_weights:
+                model._modules[name].weight = old_module.weight
+                model._modules[name].bias = old_module.bias
+
+            if post_processing_function is not None:
+                func = getattr(module, post_processing_function, None)
+                if func is not None: func(module)
+    return model
+
+
 def execute_and_return(command_string: str) -> Tuple[str, str]:
     def _decode(subprocess_err_out_tuple):
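A small sanity-check sketch (not part of the commit) of the z-score test in find_outlier_dims; the matrix and the two inflated columns are synthetic.

import torch
from bitsandbytes.utils import find_outlier_dims

w = torch.randn(1024, 512)
w[:, [7, 300]] *= 20                     # inflate two hypothetical columns

print(find_outlier_dims(w))              # default zscore=4.0 should flag columns 7 and 300
print(find_outlier_dims(w, topk=2))      # topk bypasses the z-score test and takes the k largest-std dims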
@@ -543,7 +543,9 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
     // load code through read-only cache via __ldg
     #pragma unroll NUM_PER_TH
     for(int j = 0; j < NUM_PER_TH; j++)
+    {
       vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
+    }

     __syncthreads();
     StoreT(storet).Store(&(out[i]), vals, valid_items);
@@ -2109,6 +2109,7 @@ def test_few_bit_quant():
             ebits = math.ceil(bits/2)
             pbits = bits-ebits-1
             code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
+            print(code)
         elif method == 'dynamic':
             code = F.create_dynamic_map(True, bits-0, bits).cuda()
         elif method == 'quantile':
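For reference, a tiny worked example (not from the commit) of how this test splits the bit budget: one bit for the sign, roughly half of the rest for the exponent, the remainder for the mantissa.

import math
for bits in range(2, 9):
    ebits = math.ceil(bits/2)
    pbits = bits - ebits - 1
    print(bits, ebits, pbits)    # e.g. bits=4 -> ebits=2, pbits=1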
@@ -2181,7 +2182,9 @@ def test_kbit_quantile_estimation():

 def test_bench_dequantization():
     a = torch.rand(1024, 1024, device='cuda').half()
-    qa, SA = F.quantize_blockwise(a)
+    code = F.create_fp8_map(True, 3, 0, 4).cuda()
+    qa, SA = F.quantize_blockwise(a, code=code)
+    print(qa.max())

     max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
     #print(max_theoretical_mu)
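A quick sketch (not part of the diff) of the fp8 blockwise round trip this benchmark now exercises; it assumes a CUDA device and mirrors the blocksize-64 calls used in the new modules above.

import torch
import bitsandbytes.functional as F

a = torch.rand(1024, 1024, device='cuda').half()
code = F.create_fp8_map(True, 3, 0, 4).cuda()              # 4-bit map: 3 exponent bits, 0 precision bits
qa, SA = F.quantize_blockwise(a, code=code, blocksize=64)
out = F.dequantize_blockwise(qa, SA, blocksize=64)
print((a - out).abs().float().mean().item())               # mean absolute quantization error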
@@ -2193,3 +2196,4 @@ def test_bench_dequantization():
     torch.cuda.synchronize()
     #print((time.time()-t0)/1e6)

+