2022-08-01 10:31:48 +00:00
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the MIT license found in the
|
2021-10-06 02:16:20 +00:00
|
|
|
# LICENSE file in the root directory of this source tree.
|
2022-08-01 16:32:47 +00:00
|
|
|
from typing import (
|
|
|
|
Any,
|
|
|
|
Callable,
|
|
|
|
Dict,
|
|
|
|
Iterator,
|
|
|
|
Mapping,
|
|
|
|
Optional,
|
|
|
|
Set,
|
|
|
|
Tuple,
|
|
|
|
TypeVar,
|
|
|
|
Union,
|
|
|
|
overload,
|
|
|
|
)
|
2021-10-06 02:16:20 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
import torch
|
2021-10-06 02:16:20 +00:00
|
|
|
import torch.nn.functional as F
|
2022-08-01 10:31:48 +00:00
|
|
|
from torch import Tensor, device, dtype, nn
|
|
|
|
from torch.nn.parameter import Parameter
|
2021-10-06 02:16:20 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
import bitsandbytes as bnb
|
2021-10-06 02:16:20 +00:00
|
|
|
from bitsandbytes.optim import GlobalOptimManager
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
T = TypeVar("T", bound="torch.nn.Module")
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
class StableEmbedding(torch.nn.Embedding):
    """Embedding layer whose looked-up vectors are passed through LayerNorm.

    The weight is initialised with Xavier-uniform values and an override of
    ``{"optim_bits": 32}`` is registered for it with ``GlobalOptimManager``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        """Build the embedding table, its LayerNorm and the optimizer override.

        All arguments are forwarded unchanged to ``torch.nn.Embedding``.
        """
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim)

        optim_manager = GlobalOptimManager.get_instance()
        optim_manager.register_module_override(self, "weight", {"optim_bits": 32})

    def reset_parameters(self) -> None:
        """Re-initialise the weight (Xavier uniform) and zero the padding row."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        """Zero the embedding row at ``padding_idx``, if one is configured.

        !!! This is a redefinition of _fill_padding_idx_with_zero in
        torch.nn.Embedding to make the layer compatible with PyTorch < 1.9.
        This means that if this changes in future PyTorch releases this needs
        to change too, which is cumbersome. However, with this we can ensure
        compatibility with previous PyTorch releases.
        """
        if self.padding_idx is None:
            return
        # The in-place write must not be recorded by autograd.
        with torch.no_grad():
            self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Look up ``input`` in the table and return the LayerNorm of the result."""
        looked_up = F.embedding(
            input,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )
        return self.norm(looked_up)
|
2021-11-29 17:32:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Embedding(torch.nn.Embedding):
    """``torch.nn.Embedding`` variant with Xavier-uniform initialisation.

    On construction an override of ``{"optim_bits": 32}`` is registered for
    the weight with ``GlobalOptimManager``. The forward pass is a plain
    embedding lookup.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        """Build the embedding table and register the optimizer override.

        All arguments are forwarded unchanged to ``torch.nn.Embedding``.
        """
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
        )
        optim_manager = GlobalOptimManager.get_instance()
        optim_manager.register_module_override(self, "weight", {"optim_bits": 32})

    def reset_parameters(self) -> None:
        """Re-initialise the weight (Xavier uniform) and zero the padding row."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        """Zero the embedding row at ``padding_idx``, if one is configured.

        !!! This is a redefinition of _fill_padding_idx_with_zero in
        torch.nn.Embedding to make the layer compatible with PyTorch < 1.9.
        This means that if this changes in future PyTorch releases this needs
        to change too, which is cumbersome. However, with this we can ensure
        compatibility with previous PyTorch releases.
        """
        if self.padding_idx is None:
            return
        # The in-place write must not be recorded by autograd.
        with torch.no_grad():
            self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Return the embedding vectors selected by the indices in ``input``."""
        return F.embedding(
            input,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
class Int8Params(torch.nn.Parameter):
    """Parameter that carries 8-bit quantisation state alongside its data.

    Attributes (all stored per instance):
        has_fp16_weights: when true, ``cuda`` moves the data without
            quantising it; when false, ``cuda`` quantises the data with
            ``bnb.functional.double_quant``.
        CB: first value returned by ``double_quant`` (also becomes ``data``),
            or ``None`` before quantisation.
        SCB: third value returned by ``double_quant``, or ``None``.
    """

    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        """Create the parameter; ``data`` defaults to an empty tensor."""
        if data is None:
            data = torch.empty(0)
        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
        # Keep the state on the instance. Writing it to ``cls`` would make
        # every construction overwrite the flag seen by all other parameters.
        obj.has_fp16_weights = has_fp16_weights
        obj.CB = CB
        obj.SCB = SCB
        return obj

    def __deepcopy__(self, memo):
        """Deep-copy the parameter together with its quantisation state.

        The inherited implementation rebuilds the object from ``data`` and
        ``requires_grad`` only, which would reset ``has_fp16_weights``,
        ``CB`` and ``SCB`` to their defaults.
        """
        import copy

        if id(self) in memo:
            return memo[id(self)]
        clone = type(self)(
            self.data.clone(memory_format=torch.preserve_format),
            requires_grad=self.requires_grad,
            has_fp16_weights=self.has_fp16_weights,
            CB=copy.deepcopy(self.CB, memo),
            SCB=copy.deepcopy(self.SCB, memo),
        )
        memo[id(self)] = clone
        return clone

    def cuda(self, device=None):
        """Move the parameter to ``device``, quantising unless fp16 weights are kept.

        With ``has_fp16_weights`` the call is delegated to the base class.
        Otherwise the data is converted to half precision on the GPU,
        quantised, and ``self`` is returned with ``data``/``CB``/``SCB`` set.
        """
        if self.has_fp16_weights:
            return super().cuda(device)

        # we store the 8-bit row-major weight
        # we convert this weight to the turing/ampere weight during the first inference pass
        B = self.data.contiguous().half().cuda(device)
        CB, CBt, SCB, SCBt, _ = bnb.functional.double_quant(B)
        # Only the first and third results are kept; drop the rest right away.
        del CBt
        del SCBt
        self.data = CB
        self.CB = CB
        self.SCB = SCB
        return self

    @overload
    def to(
        self,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> "Int8Params":
        ...

    @overload
    def to(self, dtype: Union[dtype, str], non_blocking: bool = ...) -> "Int8Params":
        ...

    @overload
    def to(self, tensor: Tensor, non_blocking: bool = ...) -> "Int8Params":
        ...

    def to(self, *args, **kwargs):
        """Move/cast the parameter, accepting the same arguments as ``Tensor.to``.

        A CPU -> CUDA move is routed through ``cuda`` so that quantisation can
        happen there. Every other request returns a new ``Int8Params`` that
        wraps the converted tensor and carries over ``requires_grad``,
        ``has_fp16_weights``, ``CB`` and ``SCB``.
        """
        target_device, target_dtype, non_blocking, _ = torch._C._nn._parse_to(
            *args, **kwargs
        )

        cpu_to_cuda = (
            target_device is not None
            and target_device.type == "cuda"
            and self.data.device.type == "cpu"
        )
        if cpu_to_cuda:
            return self.cuda(target_device)

        converted = super().to(
            device=target_device, dtype=target_dtype, non_blocking=non_blocking
        )
        return Int8Params(
            converted,
            requires_grad=self.requires_grad,
            has_fp16_weights=self.has_fp16_weights,
            CB=self.CB,
            SCB=self.SCB,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class Linear8bitLt(nn.Linear):
    """Linear layer whose matrix multiply is performed by ``bnb.matmul``.

    The layer keeps a ``bnb.MatmulLtState`` in ``self.state`` and stores its
    weight as an ``Int8Params`` parameter.
    """

    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        threshold=0.0,
        index=None,
    ):
        """Create the layer and configure its matmul state.

        Args:
            input_features: forwarded to ``nn.Linear`` as the input size.
            output_features: forwarded to ``nn.Linear`` as the output size.
            bias: forwarded to ``nn.Linear``.
            has_fp16_weights: stored on the state and on the weight parameter.
            threshold: stored on the state; a positive value combined with
                ``has_fp16_weights=False`` also sets ``state.use_pool``.
            index: stored as ``self.index``; not used inside this class.
        """
        super().__init__(input_features, output_features, bias)

        matmul_state = bnb.MatmulLtState()
        matmul_state.threshold = threshold
        matmul_state.has_fp16_weights = has_fp16_weights
        if threshold > 0.0 and not has_fp16_weights:
            matmul_state.use_pool = True

        self.state = matmul_state
        self.index = index

        # Re-wrap the freshly initialised weight so it carries the 8-bit flags.
        self.weight = Int8Params(
            self.weight.data, has_fp16_weights=has_fp16_weights
        )

    def init_8bit_state(self):
        """Hand ``CB``/``SCB`` from the weight over to the matmul state."""
        weight, state = self.weight, self.state
        state.CB, state.SCB = weight.CB, weight.SCB
        weight.CB = weight.SCB = None

    def forward(self, x):
        """Return ``bnb.matmul(x, weight)`` plus the bias, if there is one."""
        state = self.state
        state.is_training = self.training

        # The weight still owns quantised buffers: move them to the state first.
        if self.weight.CB is not None:
            self.init_8bit_state()

        out = bnb.matmul(x, self.weight, state=state)

        bias = self.bias
        if bias is not None:
            out += bias.unsqueeze(0).expand_as(out)

        if not state.has_fp16_weights and state.CB is not None:
            # we converted 8-bit row major to turing/ampere format in the first inference pass
            # we no longer need the row-major weight
            del state.CB
            self.weight.data = state.CxB

        return out
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
class Linear8bit(nn.Linear):
    """Linear layer that runs its forward pass through bitsandbytes 8-bit kernels.

    When ``args`` is supplied, the weight is periodically clipped and the
    sparse-decomposed kernel is used; the attributes read from ``args`` are
    ``clip_freq``, ``clip_idx``, ``sparse_decomp_val`` and ``quant_type``.
    When ``args`` is ``None``, the plain 8-bit kernel is used with the
    ``quant_type`` given to the constructor and no clipping takes place.
    """

    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        quant_type="vector",
        index=None,
        args=None,
        sparse_decomp=False,
    ):
        """Create the layer.

        Args:
            input_features: forwarded to ``nn.Linear`` as the input size.
            output_features: forwarded to ``nn.Linear`` as the output size.
            bias: forwarded to ``nn.Linear``.
            quant_type: quantisation type used when ``args`` is ``None``.
            index: stored as ``self.index``; not used inside this class.
            args: optional configuration object (see class docstring).
            sparse_decomp: accepted for interface compatibility; not used
                by this class.
        """
        super().__init__(input_features, output_features, bias)
        self.quant_type = quant_type
        self.index = index
        self.args = args
        self.iter = 0  # number of forward calls seen so far

    def _clip_weights(self):
        """Clip the weight in place to +/- its ``args.clip_idx``-th largest magnitude."""
        # Imported here because the module does not import torch.distributed.
        import torch.distributed as dist

        with torch.no_grad():
            maxval, _ = torch.topk(
                torch.abs(self.weight.flatten()), k=self.args.clip_idx
            )
            bound = maxval[-1]
            # Report only once per job: always when not running distributed,
            # otherwise on rank 0 only.
            distributed = dist.is_available() and dist.is_initialized()
            if not distributed or dist.get_rank() == 0:
                print("clip", bound.item())
            self.weight.clip_(-bound, bound)

    def forward(self, x):
        """Apply the 8-bit linear transform to ``x`` and return the result."""
        self.iter += 1

        if self.args is None:
            # No configuration object: nothing to clip against, and the
            # quantisation type comes from the constructor argument.
            return bnb.nn.functional.linear8bit(
                x, self.weight, self.bias, quant_type=self.quant_type
            )

        if self.iter % self.args.clip_freq == 0:
            self._clip_weights()

        return bnb.nn.functional.sparse_decomposed_linear8bit(
            x,
            self.weight,
            self.bias,
            qval=self.args.sparse_decomp_val,
            quant_type=self.args.quant_type,
        )
|