2023-08-02 21:53:35 +00:00
|
|
|
from contextlib import contextmanager
|
|
|
|
|
2024-05-04 16:48:26 +00:00
|
|
|
import math
|
2023-08-02 21:53:35 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
2024-08-29 18:27:16 +00:00
|
|
|
import logging
|
2024-05-04 16:48:26 +00:00
|
|
|
|
2023-08-02 23:36:26 +00:00
|
|
|
from ..config import cfg
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2024-08-29 18:27:16 +00:00
|
|
|
_logger = logging.getLogger(name=__name__)

# Stock PyTorch implementations. The cfg.optimizations flags later in this
# module may rebind any of these names (bitsandbytes, BitNet, D-Adaptation,
# Transformer Engine), so callers should import them from here.
Embedding, Linear = torch.nn.Embedding, torch.nn.Linear

Adam, AdamW, SGD, Adagrad = (
    torch.optim.Adam,
    torch.optim.AdamW,
    torch.optim.SGD,
    torch.optim.Adagrad,
)
|
|
|
|
|
2024-03-01 02:29:17 +00:00
|
|
|
# https://github.com/kyegomez/BitNet
# BitLinear is only imported when the BitNet optimization is enabled; it is
# used by the bitsandbytes Linear selection below.
if cfg.optimizations.bitnet:
    from bitnet import BitLinear

if cfg.optimizations.bitsandbytes:
    import bitsandbytes as bnb

    if cfg.optimizations.linear:
        # When both are enabled, BitNet's BitLinear wins over the 8-bit linear.
        if cfg.optimizations.bitnet:
            Linear = BitLinear
        else:
            Linear = bnb.nn.Linear8bitLt

    if cfg.optimizations.embedding:
        Embedding = bnb.nn.modules.Embedding
        # Disabled override, kept for reference:
        # Embedding.forward = lambda self, input: ( self.norm(F.embedding(
        #     input,
        #     self.weight,
        #     self.padding_idx,
        #     self.max_norm,
        #     self.norm_type,
        #     self.scale_grad_by_freq,
        #     self.sparse,
        # )).to(self.weight.dtype) )

    if cfg.optimizations.optimizers:
        Adam = bnb.optim.Adam8bit
        AdamW = bnb.optim.AdamW8bit
        SGD = bnb.optim.SGD8bit
        Adagrad = bnb.optim.Adagrad8bit

elif cfg.optimizations.dadaptation:
    import dadaptation

    if cfg.optimizations.optimizers:
        Adam = dadaptation.DAdaptAdam
        # NOTE(review): DAdaptAdam is reused for AdamW; presumably decoupled
        # weight decay has to be requested via its own constructor args — confirm.
        AdamW = dadaptation.DAdaptAdam
        SGD = dadaptation.DAdaptSGD
        # Must rebind the module-level `Adagrad` name (not `AdaGrad`), otherwise
        # Adagrad silently stays torch.optim.Adagrad in this configuration.
        Adagrad = dadaptation.DAdaptAdaGrad
        # Alias kept for any caller that used the old misspelled name.
        AdaGrad = Adagrad
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2024-05-03 01:08:59 +00:00
|
|
|
if cfg.optimizations.fp8:
    import transformer_engine.pytorch as te

    Linear = te.Linear

    @contextmanager
    def autocast():
        """Run the enclosed block under Transformer Engine FP8 autocast.

        Yields whatever ``te.fp8_autocast(enabled=True)`` yields on entry.
        """
        # The wrapped context must be *entered*; a @contextmanager that merely
        # yields another context manager never activates it.
        with te.fp8_autocast(enabled=True) as ctx:
            yield ctx
else:
    @contextmanager
    def autocast():
        """Run the enclosed block under ``torch.autocast`` on "cuda".

        The dtype and enabled flag come from ``cfg.trainer.dtype`` and
        ``cfg.trainer.amp``; they are read each time the context is entered.
        """
        # Enter the autocast context rather than yielding the unentered object.
        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp) as ctx:
            yield ctx
|
|
|
|
|
2024-05-10 01:28:20 +00:00
|
|
|
if cfg.optimizations.injects:
    # Monkey-patch torch itself so that attribute lookups such as
    # `torch.nn.Linear` / `torch.optim.Adam` made after this module is imported
    # resolve to the implementations selected above.
    if cfg.optimizations.linear:
        torch.nn.Linear = Linear

    if cfg.optimizations.embedding:
        torch.nn.Embedding = Embedding

    if cfg.optimizations.optimizers:
        torch.optim.Adam = Adam
        torch.optim.AdamW = AdamW
        torch.optim.SGD = SGD
        # Adagrad is swapped together with the other optimizer aliases, so it
        # is injected as well to keep torch.optim consistent with them.
        torch.optim.Adagrad = Adagrad
|
2023-09-07 01:33:16 +00:00
|
|
|
|
2024-08-04 03:10:21 +00:00
|
|
|
# Names accepted by compile_model(); populated best-effort below.
AVAILABLE_COMPILE_BACKENDS = []

try:
    # Import explicitly instead of relying on torch exposing `_dynamo` as an
    # attribute of the top-level package.
    import torch._dynamo
    AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
except Exception as e:
    # Best effort: leave the list empty, but report why (like the other
    # optional imports in this module) instead of failing silently.
    _logger.warning(f'Error while listing torch.compile backends: {str(e)}')
|
|
|
|
|
|
|
|
|
|
|
|
if cfg.optimizations.tensorrt:
    # The torch_tensorrt module object is not used directly here; presumably
    # it is imported for its side effects (backend registration) — confirm.
    try:
        import torch_tensorrt
    except Exception as e:
        _logger.warning(f'Error while importing TensorRT: {str(e)}')
    else:
        # Only advertise the backend once the import has succeeded.
        AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
|
|
|
|
|
2024-10-22 23:12:39 +00:00
|
|
|
if cfg.optimizations.unsloth:
    # Only the import happens here; the monkey patch itself is currently not
    # applied (its call is commented out below).
    try:
        from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
    except Exception as e:
        _logger.warning(f'Error while importing Unsloth: {str(e)}')
    #apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
|
|
|
|
|
2024-12-11 02:13:21 +00:00
|
|
|
# Optional APOLLO optimizer from the local ext package. If the import fails,
# a warning is logged and `Apollo` is simply left undefined in this module.
try:
    from .ext.apollo import Apollo
except Exception as e:
    _logger.warning(f'Error while importing APOLLO: {str(e)}')
|
|
|
|
|
2024-08-04 03:10:21 +00:00
|
|
|
def compile_model(model, backend="auto"):
    """Wrap ``model`` with ``torch.compile``.

    Args:
        model: the module or callable to compile.
        backend: a name from ``AVAILABLE_COMPILE_BACKENDS``; ``"auto"`` or any
            falsy value selects the first entry of that list. A name that is
            not in the list falls back to ``torch.compile``'s default backend.

    Returns:
        Whatever ``torch.compile`` returns for ``model``.
    """
    if not backend or backend == "auto":
        # The list can be empty when backend discovery failed; in that case
        # fall through to torch.compile's default instead of raising IndexError.
        # NOTE(review): "auto" takes the first listed backend, which depends on
        # list_backends() ordering and may differ from torch.compile's own
        # default — confirm this is intended.
        backend = AVAILABLE_COMPILE_BACKENDS[0] if AVAILABLE_COMPILE_BACKENDS else None

    if backend not in AVAILABLE_COMPILE_BACKENDS:
        return torch.compile(model)

    return torch.compile(model, backend=backend)
|
|
|
|
|
2023-09-07 01:33:16 +00:00
|
|
|
# https://github.com/konstmish/prodigy
# Optional dependency: on failure a warning is logged and `Prodigy` stays
# undefined in this module.
try:
    from prodigyopt import Prodigy
except Exception as e:
    _logger.warning(f'Error while importing Prodigyopt: {str(e)}')
|
|
|
|
|
|
|
|
# https://github.com/facebookresearch/schedule_free/
# Optional dependency: on failure a warning is logged and `schedulefree`
# stays undefined in this module.
try:
    import schedulefree
except Exception as e:
    _logger.warning(f'Error while importing Schedule_Free: {str(e)}')
|
|
|
|
|
|
|
|
# Backwards compatibility: these helpers live in .utils and are re-exported from this module.
|
|
|
|
from .utils import (
|
|
|
|
autocast_forward,
|
2024-08-02 03:43:39 +00:00
|
|
|
replace_linear as replace_linear_old,
|
2024-08-02 01:12:06 +00:00
|
|
|
replace_embedding as replace_embedding_old,
|
2024-08-02 03:43:39 +00:00
|
|
|
replace_attention,
|
2024-08-02 01:12:06 +00:00
|
|
|
resize_weight,
|
|
|
|
offload_model,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Thin wrappers over the .utils helpers. They exist so that the `klass`
# defaults pick up this module's (possibly swapped) Linear / Embedding; the
# defaults are bound once, when these `def` statements execute.
# NOTE(review): if `injects` + `linear` already rebound torch.nn.Linear, the
# default `target` equals `Linear` and the replacement is effectively a no-op.
def replace_linear(model, klass=Linear, target=torch.nn.Linear, verbose=False):
    """Forward to the .utils ``replace_linear`` with this module's defaults.

    Presumably replaces ``target`` layers inside ``model`` with ``klass``;
    all arguments are passed through positionally and the helper's return
    value is returned unchanged.
    """
    return replace_linear_old(model, klass, target, verbose)


def replace_embedding(model, klass=Embedding, target=torch.nn.Embedding, verbose=False):
    """Forward to the .utils ``replace_embedding`` with this module's defaults.

    Presumably replaces ``target`` embeddings inside ``model`` with ``klass``;
    all arguments are passed through positionally and the helper's return
    value is returned unchanged.
    """
    return replace_embedding_old(model, klass, target, verbose)


# Wrap Embedding.forward with autocast_forward from .utils (presumably runs the
# forward under an autocast context — see that helper).
# NOTE(review): with the default Embedding (torch.nn.Embedding) this patches
# the stock torch class for the whole process.
Embedding.forward = autocast_forward(Embedding.forward)
|