2023-08-02 21:53:35 +00:00
|
|
|
from contextlib import contextmanager
|
|
|
|
|
2024-05-04 16:48:26 +00:00
|
|
|
import math
|
2023-08-02 21:53:35 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
2024-05-04 16:48:26 +00:00
|
|
|
|
2023-08-02 23:36:26 +00:00
|
|
|
from ..config import cfg
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
# Default layer classes used by the rest of the package; the optimization
# flags below may rebind them to drop-in replacements.
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear

# https://github.com/kyegomez/BitNet
if cfg.optimizations.bitnet:
    from bitnet import BitLinear

if cfg.optimizations.bitsandbytes:
    import bitsandbytes as bnb

    if cfg.optimizations.linear:
        # BitLinear is only looked up when the bitnet flag is set, which is
        # also the only case in which it was imported above.
        Linear = BitLinear if cfg.optimizations.bitnet else bnb.nn.Linear8bitLt

    if cfg.optimizations.embedding:
        Embedding = bnb.nn.modules.Embedding
        # The string below is a bare expression statement (never executed as
        # code); it keeps an older forward override around for reference.
        """
        Embedding.forward = lambda self, input: ( self.norm(F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )).to(self.weight.dtype) )
        """
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
|
2024-05-03 01:08:59 +00:00
|
|
|
# Optimizer aliases: stock torch optimizers by default, bitsandbytes' 8-bit
# variants when the bitsandbytes optimization is enabled.
if not cfg.optimizations.bitsandbytes:
    Adam = torch.optim.Adam
    AdamW = torch.optim.AdamW
    SGD = torch.optim.SGD
    Adagrad = torch.optim.Adagrad
else:
    import bitsandbytes as bnb

    Adam = bnb.optim.Adam8bit
    AdamW = bnb.optim.AdamW8bit
    SGD = bnb.optim.SGD8bit
    Adagrad = bnb.optim.Adagrad8bit
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
# handles generically converting to a specific tensor type for the duration of a `with` block (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
    """Yield `input` cast to `to_dtype` when its dtype equals `from_dtype`.

    Args:
        input: tensor to (possibly) convert.
        from_dtype: dtype that triggers the conversion.
        to_dtype: dtype handed to the `with` body when the conversion triggers.

    Yields:
        `input.to(to_dtype)` if `input.dtype == from_dtype`, otherwise `input`
        itself, untouched.

    The caller's tensor is never modified. A context manager cannot rebind the
    caller's variable, so there is nothing to "convert back" on exit: results
    computed inside the block keep whatever dtype they were produced with and
    must be cast by the caller if `from_dtype` is needed again.

    NOTE(review): this module rebinds the name `autocast` further down (the
    zero-argument AMP/fp8 wrapper), so after import this three-argument helper
    is no longer reachable under this name.
    """
    if input.dtype == from_dtype:
        yield input.to(to_dtype)
    else:
        yield input
|
|
|
|
|
|
|
|
@contextmanager
def autocasts(input, from_dtype, to_dtype):
    """Yield `input` cast to `to_dtype` when its dtype is one of `from_dtype`.

    Multi-dtype variant of the single-dtype helper: `from_dtype` is a
    collection of dtypes and is checked with `in`.

    Args:
        input: tensor to (possibly) convert.
        from_dtype: collection of dtypes that trigger the conversion.
        to_dtype: dtype handed to the `with` body when the conversion triggers.

    Yields:
        `input.to(to_dtype)` if `input.dtype in from_dtype`, otherwise `input`
        itself, untouched.

    The caller's tensor is never modified, and nothing is cast back on exit:
    a context manager cannot rebind the caller's variable, so a cast after the
    `yield` would only build a tensor that is immediately discarded.
    """
    if input.dtype in from_dtype:
        yield input.to(to_dtype)
    else:
        yield input
|
|
|
|
|
|
|
|
# handles temporarily upcasting 'index tensors' so torch stops rejecting them
def autocast_forward( func ):
    """Wrap a module `forward` so narrow-dtype inputs are cast to int32 first.

    Args:
        func: unbound forward function taking `(self, input, *args, **kwargs)`.

    Returns:
        A function with the same call signature. When `input.dtype` is one of
        int16, int8, uint8, float16 or bfloat16, `func` receives an int32 copy
        of `input`; any other dtype is passed through unchanged. The return
        value of `func` is returned as-is.
    """
    # Local import: keeps the wrapper's metadata intact without relying on a
    # module-level `functools` import.
    from functools import wraps

    @wraps(func)
    def wrapper( self, input, *args, **kwargs ):
        with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
            return func( self, k, *args, **kwargs )
    return wrapper
|
|
|
|
# Patch the selected Embedding class in place so its forward goes through the
# dtype-upcasting wrapper. NOTE(review): when Embedding is still
# torch.nn.Embedding this modifies the torch class itself, i.e. every
# torch.nn.Embedding in the process.
Embedding.forward = autocast_forward(Embedding.forward)
|
|
|
|
|
2024-05-03 01:08:59 +00:00
|
|
|
# Zero-argument mixed-precision context: `with autocast(): ...`.
# This rebinds the module-level name `autocast`, replacing the three-argument
# dtype helper defined earlier in this file.
if cfg.optimizations.fp8:
    import transformer_engine.pytorch as te

    Linear = te.Linear

    @contextmanager
    def autocast():
        """Run the `with` body under TransformerEngine fp8 autocast.

        Yields the underlying `te.fp8_autocast` manager, which is already
        entered here and exited when the block ends; callers must not enter
        the yielded object again.
        """
        ctx = te.fp8_autocast(enabled=True)
        # Merely yielding the manager would never activate it; it has to be
        # entered for the body to actually run under autocast.
        with ctx:
            yield ctx
else:
    @contextmanager
    def autocast():
        """Run the `with` body under `torch.autocast` on CUDA.

        The dtype and enabled flag come from `cfg.trainer.dtype` and
        `cfg.trainer.amp`. Yields the underlying `torch.autocast` manager,
        which is already entered here and exited when the block ends; callers
        must not enter the yielded object again.
        """
        ctx = torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
        # Merely yielding the manager would never activate it; it has to be
        # entered for the body to actually run under autocast.
        with ctx:
            yield ctx
|
|
|
|
|
2024-05-03 01:08:59 +00:00
|
|
|
# Optional global injection: replace torch's own layer and optimizer classes
# with the ones selected above, so code that refers to torch.nn / torch.optim
# directly picks them up too. Only done when both `injects` and
# `bitsandbytes` are enabled.
if cfg.optimizations.injects and cfg.optimizations.bitsandbytes:
    torch.nn.Linear = Linear
    torch.nn.Embedding = Embedding

    torch.optim.Adam = Adam
    torch.optim.AdamW = AdamW
    torch.optim.SGD = SGD
    # Adagrad is selected together with the other optimizers above, so it is
    # injected with them for consistency.
    torch.optim.Adagrad = Adagrad
|
|
|
|
|
2024-03-02 02:18:43 +00:00
|
|
|
# crude but functional (BitNet also ships its own replacement routine)
def replace_linear( model, verbose=False ):
    """Swap every `torch.nn.Linear` submodule of `model` for the module-level `Linear` class.

    Args:
        model: module whose submodules are scanned via `named_modules()`.
        verbose: when true, print one line per replaced layer.

    Returns:
        The same `model` object, modified in place.

    Raises:
        StopIteration: if `model` has no parameters (the target device is taken
            from the first parameter).

    Each replacement is a freshly constructed layer moved to the model's device
    and `cfg.trainer.dtype`. Only the feature sizes and the presence of a bias
    are carried over; the old layer's weights are not copied here.
    """
    # When this flag is set, the constructor is called with
    # `input_features` / `output_features` instead of `in_features` / `out_features`.
    use_bnb_kwargs = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet

    device = next(model.parameters()).device
    # Collect the dotted paths up front so the module tree is not mutated
    # while it is being walked.
    targets = [
        path.split('.')
        for path, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    ]
    klass = Linear

    for *parent_parts, attr in targets:
        parent_name = '.'.join(parent_parts)
        parent = model.get_submodule(parent_name)
        old = getattr( parent, attr )

        # already the desired class, nothing to do
        if isinstance(old, klass):
            continue

        in_features = old.in_features
        out_features = old.out_features
        has_bias = old.bias is not None

        if use_bnb_kwargs:
            kwargs = dict(input_features=in_features, output_features=out_features, bias=has_bias)
        else:
            kwargs = dict(in_features=in_features, out_features=out_features, bias=has_bias)

        # overwrite the attribute on the parent module
        replacement = klass( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
        setattr( parent, attr, replacement )

        if verbose:
            print(f"Replacing {parent_name}.{attr} to", klass)

    return model
|
2024-03-02 01:20:10 +00:00
|
|
|
|
2023-09-07 01:33:16 +00:00
|
|
|
# https://github.com/konstmish/prodigy
# Optional dependency: expose `Prodigy` when prodigyopt imports cleanly and
# otherwise carry on without it (any import-time failure is ignored on purpose).
try:
    from prodigyopt import Prodigy
except Exception:
    pass
|