# vall-e/vall_e/utils/ml.py

from contextlib import contextmanager
import math
import logging
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import _LRScheduler
from ..config import cfg
_logger = logging.getLogger(__name__)
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
SGD = torch.optim.SGD
Adagrad = torch.optim.Adagrad
Adafactor = torch.optim.Adafactor
OneCycleLR = torch.optim.lr_scheduler.OneCycleLR
CosineAnnealingLR = torch.optim.lr_scheduler.CosineAnnealingLR
LambdaLR = torch.optim.lr_scheduler.LambdaLR
# implements Noam scheduling: linear warmup, then inverse-square-root decay, scaled by d_model ** -0.5
# it's cringe
class NoamLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, d_model=1024, last_epoch=-1):
        self.base_factor = d_model ** (-0.5)
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(1, self.last_epoch)
        scale = self.base_factor * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))
        return [base_lr * scale for base_lr in self.base_lrs]
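
# Usage sketch (illustrative values, not project defaults): the Noam factor already
# sets the magnitude, so the optimizer lr is typically 1.0 and acts as a multiplier.
#
#   optimizer = AdamW(model.parameters(), lr=1.0)
#   scheduler = NoamLR(optimizer, warmup_steps=4000, d_model=1024)
#   for batch in dataloader:
#       loss = model(batch)
#       loss.backward()
#       optimizer.step()
#       scheduler.step()
#       optimizer.zero_grad()
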
# gradually warms up LR then holds or decays
class WarmupLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, decay_factor=0.0, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.decay_factor = decay_factor
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1
        scale = 1
        if step < self.warmup_steps:
            scale = float(step) / float(max(1, self.warmup_steps))
        elif self.decay_factor != 0:
            scale = (1.0 - self.decay_factor) ** (step - self.warmup_steps)
        return [base_lr * scale for base_lr in self.base_lrs]
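
# Usage sketch (illustrative values): LR ramps linearly from ~0 to the optimizer's
# base lr over warmup_steps, then holds (decay_factor=0) or decays exponentially by
# a factor of (1 - decay_factor) per step.
#
#   optimizer = AdamW(model.parameters(), lr=1.0e-4)
#   scheduler = WarmupLR(optimizer, warmup_steps=1000, decay_factor=0.0)
#   # step the scheduler once per optimizer.step(), as with NoamLR above
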
# https://github.com/kyegomez/BitNet
if cfg.optimizations.bitnet:
    from bitnet import BitLinear

if cfg.optimizations.bitsandbytes:
    import bitsandbytes as bnb

    if cfg.optimizations.linear:
        if cfg.optimizations.bitnet:
            Linear = BitLinear
        else:
            Linear = bnb.nn.Linear8bitLt

    if cfg.optimizations.embedding:
        Embedding = bnb.nn.StableEmbedding
        """
        Embedding.forward = lambda self, input: ( self.norm(F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )).to(self.weight.dtype) )
        """

    if cfg.optimizations.optimizers:
        Adam = bnb.optim.Adam8bit
        AdamW = bnb.optim.AdamW8bit
        SGD = bnb.optim.SGD8bit
        Adagrad = bnb.optim.Adagrad8bit

elif cfg.optimizations.dadaptation:
    import dadaptation

    if cfg.optimizations.optimizers:
        # dadaptation ships no separate AdamW class, so DAdaptAdam covers both aliases
        Adam = dadaptation.DAdaptAdam
        AdamW = dadaptation.DAdaptAdam
        SGD = dadaptation.DAdaptSGD
        Adagrad = dadaptation.DAdaptAdaGrad

if cfg.optimizations.fp8:
    import transformer_engine.pytorch as te

    Linear = te.Linear

    @contextmanager
    def autocast():
        # enter the fp8 autocast region for the caller's `with autocast():` block
        with te.fp8_autocast(enabled=True):
            yield
else:
    @contextmanager
    def autocast():
        with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
            yield
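
# Usage sketch: call sites stay identical whether this resolves to transformer_engine's
# fp8 autocast or torch.autocast (model/inputs below are placeholders).
#
#   with autocast():
#       logits = model(inputs)
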
if cfg.optimizations.injects:
    if cfg.optimizations.linear:
        torch.nn.Linear = Linear

    if cfg.optimizations.embedding:
        torch.nn.Embedding = Embedding

    if cfg.optimizations.optimizers:
        torch.optim.Adam = Adam
        torch.optim.AdamW = AdamW
        torch.optim.SGD = SGD

if cfg.optimizations.unsloth:
    try:
        from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
        #apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
    except Exception as e:
        _logger.warning(f'Error while importing Unsloth: {str(e)}')
        pass

# aggregates multiple optimizers behind a single step()/zero_grad()/state_dict() interface
class Optimizers(torch.optim.Optimizer):
    def __init__(self, opts):
        self.opts = opts

    def step(self, *args, **kwargs):
        for opt in self.opts:
            opt.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        for opt in self.opts:
            opt.zero_grad(*args, **kwargs)

    @property
    def param_groups(self):
        l = []
        for opt in self.opts:
            l += opt.param_groups
        return l

    def state_dict(self):
        states = []
        for opt in self.opts:
            states.append( opt.state_dict() )
        return states

    def load_state_dict(self, state_dict):
        for opt, state in zip( self.opts, state_dict ):
            opt.load_state_dict( state )
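
# Usage sketch: stepping two optimizers over disjoint parameter groups as if they
# were one; the >=2-dim / <2-dim split below is illustrative, not prescribed here.
#
#   matrix_params = [p for p in model.parameters() if p.ndim >= 2]
#   other_params  = [p for p in model.parameters() if p.ndim < 2]
#   optimizer = Optimizers([
#       SGD(matrix_params, lr=1.0e-2),
#       AdamW(other_params, lr=1.0e-4),
#   ])
#   optimizer.step()
#   optimizer.zero_grad()
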
try:
    from .ext.apollo import Apollo
except Exception as e:
    _logger.warning(f'Error while importing APOLLO: {str(e)}')
    pass

try:
    from .ext.muon import Muon
except Exception as e:
    _logger.warning(f'Error while importing Muon: {str(e)}')
    pass

# https://github.com/konstmish/prodigy
try:
    from prodigyopt import Prodigy
except Exception as e:
    _logger.warning(f'Error while importing Prodigyopt: {str(e)}')
    pass

# https://github.com/facebookresearch/schedule_free/
try:
    import schedulefree
except Exception as e:
    _logger.warning(f'Error while importing Schedule_Free: {str(e)}')
    pass

# backwards compat
from .utils import (
    autocast_forward,
    replace_linear as replace_linear_old,
    replace_embedding as replace_embedding_old,
    replace_attention,
    resize_weight,
    offload_model,
)

# wrapped here so we can maintain default args
def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
    return replace_linear_old( model, klass, target, verbose )

def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
    return replace_embedding_old( model, klass, target, verbose )
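
# Usage sketch: retrofit an already-built model with whatever Linear/Embedding this
# module resolved to (e.g. BitLinear or bitsandbytes variants); assumes the underlying
# replace_*_old helpers return the patched model.
#
#   model = replace_linear(model)
#   model = replace_embedding(model)
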
Embedding.forward = autocast_forward(Embedding.forward)

AVAILABLE_COMPILE_BACKENDS = []

try:
    AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
except Exception as e:
    pass

def compile_model(model, backend="auto"):
    # "auto" resolves to the first backend torch._dynamo reported, if any
    if not backend or backend == "auto":
        backend = AVAILABLE_COMPILE_BACKENDS[0] if AVAILABLE_COMPILE_BACKENDS else None

    if not backend or backend not in AVAILABLE_COMPILE_BACKENDS:
        return torch.compile(model)

    return torch.compile(model, backend=backend)
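
# Usage sketch: "auto" picks the first backend torch._dynamo reported; an unknown
# backend falls back to torch.compile's default.
#
#   model = compile_model(model)                      # auto/default backend
#   model = compile_model(model, backend="inductor")  # explicit backend, if available
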
if cfg.optimizations.tensorrt:
    try:
        import torch_tensorrt

        AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
    except Exception as e:
        _logger.warning(f'Error while importing TensorRT: {str(e)}')
        pass