bitsandbytes-rocm/bitsandbytes/nn/modules.py

316 lines
10 KiB
Python
Raw Normal View History

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
2022-10-27 11:32:01 +00:00
from typing import Optional, TypeVar, Union, overload
2021-10-06 02:16:20 +00:00
import torch
2021-10-06 02:16:20 +00:00
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
2021-10-06 02:16:20 +00:00
import bitsandbytes as bnb
2021-10-06 02:16:20 +00:00
from bitsandbytes.optim import GlobalOptimManager
T = TypeVar("T", bound="torch.nn.Module")
2022-07-22 21:41:05 +00:00
2021-10-06 02:16:20 +00:00
class StableEmbedding(torch.nn.Embedding):
    """Embedding followed by a LayerNorm, registered for 32-bit optimizer state.

    The weight is initialised with Xavier-uniform values.  In ``forward`` the
    looked-up vectors are cast to ``torch.get_default_dtype()``, passed
    through ``self.norm`` (a ``torch.nn.LayerNorm`` over ``embedding_dim``)
    and cast back to the dtype of ``self.weight``.

    On construction the module registers an ``{"optim_bits": 32}`` override
    for its ``weight`` with ``GlobalOptimManager``.

    The constructor arguments are forwarded unchanged to
    ``torch.nn.Embedding``; ``device`` is additionally used for the
    LayerNorm.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device=None,
        dtype=None,
    ) -> None:
        # ``device`` and ``dtype`` are passed by keyword on purpose: newer
        # PyTorch releases insert a ``_freeze`` parameter between ``_weight``
        # and ``device``, so positional passing would bind ``device`` to
        # ``_freeze`` and ``dtype`` to ``device``.
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device=device,
            dtype=dtype,
        )
        # The LayerNorm is created without ``dtype`` so that it keeps the
        # default dtype; ``forward`` casts its input accordingly.
        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        """Xavier-uniform initialise the weight, then zero the padding row."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    # !!! This is a redefinition of _fill_padding_idx_with_zero in
    # torch.nn.Embedding to make the layer compatible with PyTorch < 1.9.
    # If the upstream method changes in a future PyTorch release, this copy
    # needs to change too, which is cumbersome.  However, with this we can
    # ensure compatibility with previous PyTorch releases.
    def _fill_padding_idx_with_zero(self) -> None:
        """Zero the row at ``padding_idx`` in place; no-op when it is None."""
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Look up ``input``, layer-normalise, and return in the weight dtype."""
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        # always apply layer norm in full precision
        emb = emb.to(torch.get_default_dtype())

        return self.norm(emb).to(self.weight.dtype)
class Embedding(torch.nn.Embedding):
    """``torch.nn.Embedding`` with Xavier-uniform init and 32-bit optimizer state.

    All constructor arguments are forwarded to ``torch.nn.Embedding``.  After
    construction the module registers an ``{"optim_bits": 32}`` override for
    its ``weight`` with ``GlobalOptimManager``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
        )
        optim_manager = GlobalOptimManager.get_instance()
        optim_manager.register_module_override(self, "weight", {"optim_bits": 32})

    def reset_parameters(self) -> None:
        """Xavier-uniform initialise the weight, then zero the padding row."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    # !!! This is a redefinition of _fill_padding_idx_with_zero in
    # torch.nn.Embedding to make the layer compatible with PyTorch < 1.9.
    # If the upstream method changes in a future PyTorch release, this copy
    # needs to change too, which is cumbersome.  However, with this we can
    # ensure compatibility with previous PyTorch releases.
    def _fill_padding_idx_with_zero(self) -> None:
        """Zero the row at ``padding_idx`` in place; no-op when it is None."""
        if self.padding_idx is None:
            return
        with torch.no_grad():
            self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Return the embedding vectors selected by ``input``."""
        return F.embedding(
            input,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )
2022-07-22 21:41:05 +00:00
2023-02-05 05:11:21 +00:00
class FP4Params(torch.nn.Parameter):
    """Parameter that is quantized to FP4 when it is moved from CPU to CUDA.

    Attributes:
        quant_state: value returned by ``bnb.functional.quantize_fp4``
            alongside the quantized data, or ``None`` while the parameter
            has not been quantized.  Stored per instance.
    """

    def __new__(cls, data=None, requires_grad=True, quant_state=None):
        """Wrap ``data`` (an empty tensor when None) and attach ``quant_state``."""
        if data is None:
            data = torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        # Keep the state on the instance: a class attribute would be shared
        # by every FP4Params and would discard the ``quant_state`` argument.
        self.quant_state = quant_state
        return self

    def cuda(self, device=None):
        """Move the data to ``device`` as fp16, quantize it and return ``self``.

        ``self.data`` is replaced by the first value returned by
        ``bnb.functional.quantize_fp4`` and ``self.quant_state`` by the
        second.
        """
        w = self.data.contiguous().half().cuda(device)
        w_fp4, quant_state = bnb.functional.quantize_fp4(w)
        self.data = w_fp4
        self.quant_state = quant_state
        return self

    @overload
    def to(
        self,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> "FP4Params":
        ...

    @overload
    def to(self, dtype: Union[dtype, str], non_blocking: bool = ...) -> "FP4Params":
        ...

    @overload
    def to(self, tensor: Tensor, non_blocking: bool = ...) -> "FP4Params":
        ...

    def to(self, *args, **kwargs):
        """Mirror ``Tensor.to``; a CPU -> CUDA move quantizes via ``cuda``.

        Any other conversion returns a new ``FP4Params`` wrapping the
        converted data and carrying over ``requires_grad`` and
        ``quant_state``.
        """
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)

        moving_cpu_to_cuda = (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        )
        if moving_cpu_to_cuda:
            return self.cuda(device)

        return FP4Params(
            super().to(device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=self.quant_state,
        )
class LinearFP4(nn.Linear):
    """Linear layer whose weight is an ``FP4Params`` and whose matmul is ``bnb.matmul_fp4``.

    Args:
        input_features: forwarded to ``nn.Linear`` as the input size.
        output_features: forwarded to ``nn.Linear`` as the output size.
        bias: forwarded to ``nn.Linear``; when false the layer has no bias.
    """

    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.state = bnb.MatmulLtState()
        # Re-wrap the freshly initialised weight; it is never trained directly.
        self.weight = FP4Params(self.weight.data, requires_grad=False)

    def init_8bit_state(self):
        """No-op for this layer."""
        pass

    def forward(self, x: torch.Tensor):
        """Compute the FP4 matmul in fp16 and return it in the dtype of ``x``."""
        self.state.is_training = self.training

        # The weight is handled by FP4Params, but the bias has to be cast manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, 'quant_state', None) is None:
            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')

        inp_dtype = x.dtype
        x = x.to(torch.float16)
        # A layer built with bias=False has ``self.bias is None``; calling
        # ``.half()`` on it would raise AttributeError.
        # NOTE(review): assumes bnb.matmul_fp4 accepts bias=None - confirm.
        bias = None if self.bias is None else self.bias.half()
        out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
        out = out.to(inp_dtype)

        return out
2022-07-22 21:41:05 +00:00
class Int8Params(torch.nn.Parameter):
    """Parameter that is quantized to 8 bit when moved from CPU to CUDA.

    Attributes (all stored per instance):
        has_fp16_weights: when true, ``cuda`` skips quantization and defers
            to the parent implementation.
        CB: first value returned by ``bnb.functional.double_quant`` (also
            stored as ``data``), or ``None``.
        SCB: third value returned by ``bnb.functional.double_quant``, or
            ``None``.
    """

    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        """Wrap ``data`` (an empty tensor when None) and attach the 8-bit state."""
        if data is None:
            data = torch.empty(0)
        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
        # Set these on the instance, not on ``cls``: class attributes are
        # shared, so every construction would overwrite the flag seen by
        # all earlier parameters, and the CB/SCB arguments would be lost.
        obj.has_fp16_weights = has_fp16_weights
        obj.CB = CB
        obj.SCB = SCB
        return obj

    def cuda(self, device=None):
        """Move to ``device``; quantize unless ``has_fp16_weights`` is set.

        When quantizing, ``self.data`` and ``self.CB`` are set to the
        quantized weight, ``self.SCB`` to its companion statistics, and
        ``self`` is returned.
        """
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # we store the 8-bit row-major weight
            # we convert this weight to the turing/ampere weight during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            # The transposed results are not kept; drop the references now.
            del CBt
            del SCBt
            self.data = CB
            setattr(self, "CB", CB)
            setattr(self, "SCB", SCB)

        return self

    @overload
    def to(
        self,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> "Int8Params":
        ...

    @overload
    def to(self, dtype: Union[dtype, str], non_blocking: bool = ...) -> "Int8Params":
        ...

    @overload
    def to(self, tensor: Tensor, non_blocking: bool = ...) -> "Int8Params":
        ...

    def to(self, *args, **kwargs):
        """Mirror ``Tensor.to``; a CPU -> CUDA move goes through ``cuda``.

        Any other conversion returns a new ``Int8Params`` wrapping the
        converted data and carrying over ``requires_grad``,
        ``has_fp16_weights``, ``CB`` and ``SCB``.
        """
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
                CB=self.CB,
                SCB=self.SCB,
            )

            return new_param
2023-02-05 05:11:21 +00:00
2022-07-22 21:41:05 +00:00
class Linear8bitLt(nn.Linear):
    """Linear layer whose weight is an ``Int8Params`` and whose matmul is ``bnb.matmul``.

    Args:
        input_features: forwarded to ``nn.Linear`` as the input size.
        output_features: forwarded to ``nn.Linear`` as the output size.
        bias: forwarded to ``nn.Linear``.
        has_fp16_weights: stored on the matmul state and on the weight; also
            used as the weight's ``requires_grad``.
        memory_efficient_backward: deprecated; must be false.
        threshold: stored on the matmul state.
        index: stored as ``self.index``.
    """

    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__(input_features, output_features, bias)
        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"

        # Configure the matmul state fully before attaching it to the module.
        state = bnb.MatmulLtState()
        state.threshold = threshold
        state.has_fp16_weights = has_fp16_weights
        state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            state.use_pool = True

        self.state = state
        self.index = index
        self.weight = Int8Params(
            self.weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )

    def init_8bit_state(self):
        """Move ``CB``/``SCB`` from the weight onto the matmul state."""
        weight, state = self.weight, self.state
        state.CB = weight.CB
        state.SCB = weight.SCB
        weight.CB = None
        weight.SCB = None

    def forward(self, x: torch.Tensor):
        """Run ``bnb.matmul`` on ``x`` with this layer's weight, bias and state."""
        self.state.is_training = self.training

        # The weight still carries its quantized buffers: hand them to the state.
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            bias.data = bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if (
            not self.state.has_fp16_weights
            and self.state.CB is not None
            and self.state.CxB is not None
        ):
            # we converted 8-bit row major to turing/ampere format in the first inference pass
            # we no longer need the row-major weight
            del self.state.CB
            self.weight.data = self.state.CxB

        return out