from typing import Optional, TypeVar, Union, overload

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims

T = TypeVar("T", bound="torch.nn.Module")


class LinearFP8Mixed(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.bw_code = None
        self.fw_code = None
        # Round each feature dimension up to the nearest size in the table
        # (capped at 4096) to choose the quantization block sizes:
        # bsz for the input dimension, bsz2 for the output dimension.
        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
        for i, k in enumerate(array):
            if input_features > array[i + 1]:
                self.bsz = k
                break
        for i, k in enumerate(array):
            if output_features > array[i + 1]:
                self.bsz2 = k
                break

    def forward(self, x: torch.Tensor):
        if self.fw_code is None:
            # Lazily build the FP8 quantization maps on first use:
            # E5M2 (5 exponent / 2 mantissa bits) for the backward pass,
            # E4M3 (4 exponent / 3 mantissa bits) for the forward pass.
            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

        out = bnb.research.matmul_fp8_mixed(
            x,
            self.weight.t(),
            fw_code=self.fw_code,
            bw_code=self.bw_code,
            bsz=self.bsz,
            bsz2=self.bsz2,
        )
        if self.bias is not None:
            out += self.bias

        return out


class LinearFP8Global(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.bw_code = None
        self.fw_code = None
        # Same block-size selection as LinearFP8Mixed.
        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
        for i, k in enumerate(array):
            if input_features > array[i + 1]:
                self.bsz = k
                break
        for i, k in enumerate(array):
            if output_features > array[i + 1]:
                self.bsz2 = k
                break

    def forward(self, x: torch.Tensor):
        if self.fw_code is None:
            # Lazily build the FP8 quantization maps on first use:
            # E5M2 (5 exponent / 2 mantissa bits) for the backward pass,
            # E4M3 (4 exponent / 3 mantissa bits) for the forward pass.
            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

        out = bnb.research.matmul_fp8_global(
            x,
            self.weight.t(),
            fw_code=self.fw_code,
            bw_code=self.bw_code,
            bsz=self.bsz,
            bsz2=self.bsz2,
        )
        if self.bias is not None:
            out += self.bias

        return out
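

# Illustrative usage sketch (not part of the module): both classes are drop-in
# replacements for nn.Linear whose matmul quantizes operands block-wise to FP8.
# The shapes, dtype, and CUDA device below are assumptions for the example,
# not requirements stated in this file.
#
#     layer = LinearFP8Mixed(1024, 4096).cuda().half()
#     x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
#     out = layer(x)            # forward pass uses the E4M3 map
#     out.sum().backward()      # backward pass uses the E5M2 map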