added muon optimizer through kludge hacks, since it necessitates running a second optimizer in tandem, and that pairing only sometimes seems to work with deepspeed

mrq 2025-02-23 11:22:13 -06:00
parent 67a6009555
commit 6634d07576
10 changed files with 110 additions and 60 deletions
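Conceptually the change splits a model's parameters by dimensionality and drives two optimizers in tandem. A minimal sketch of that pattern with stock torch optimizers (the SGD here is only a stand-in for ml.Muon; the toy model and hyperparameters are illustrative, not taken from this repo):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Muon is only intended for >=2D weight matrices; biases, norms, and other
# 1D parameters still need a conventional optimizer running alongside it.
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
other_params  = [p for p in model.parameters() if p.ndim < 2]

opt_a = torch.optim.SGD(matrix_params, lr=0.02, momentum=0.95)  # stand-in for ml.Muon
opt_b = torch.optim.AdamW(other_params, lr=1e-4)

x, y = torch.randn(8, 16), torch.randn(8, 4)
for opt in (opt_a, opt_b):
    opt.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
for opt in (opt_a, opt_b):  # both optimizers step together every iteration
    opt.step()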

View File

@ -11,7 +11,7 @@ elif cfg.trainer.backend == "local":
from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
from ..models import get_models, get_model
from ..utils import wrapper as ml
from ..utils import ml
from ..utils.io import torch_save, torch_load, pick_path
from ..models.lora import apply_lora, lora_load_state_dict
@ -114,7 +114,6 @@ def load_engines(training=True, **model_kwargs):
"lr": cfg.hyperparameters.learning_rate,
}
if cfg.hyperparameters.optimizer.lower() == "adamw":
params["betas"] = (0.9, 0.96)
params["eps"] = 1e-07
@ -149,8 +148,21 @@ def load_engines(training=True, **model_kwargs):
else:
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(**params)
if muon_params is not None:
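# Muon only optimizes the >=2D weight matrices; 1D params (biases, norms) and anything outside model.* fall through to the base optimizer below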
muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 and f'model.{name}' not in model.config.frozen_params ]
params["params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') and name not in model.config.frozen_params ]
optimizer = ml.Optimizers([
ml.Muon(**muon_params),
optimizer_class(**params),
])
else:
optimizer = optimizer_class(**params)
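# assumed config shape for the branch above (an illustration, not taken from the repo's docs):
# hyperparameters:
#   optimizer: AdamW
#   optimizer_params:
#     muon:
#       lr: 0.01          # everything under "muon" is forwarded verbatim to ml.Muon
#     weight_decay: 0.01  # everything else is forwarded to the base optimizer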
if cfg.hyperparameters.scheduler.lower() == "schedulefree":
if cfg.hyperparameters.optimizer.lower() == "adamw":
@ -406,7 +418,7 @@ def load_engines(training=True, **model_kwargs):
try:
engine.wandb = wandb.init(project=key_name, **kwargs)
engine.wandb.watch(engine.module)
except Exception as e:
engine.wandb = None
else:
engine.wandb = None

View File

@ -45,7 +45,7 @@ from typing import Any, Protocol
from functools import cached_property
from .base import TrainFeeder
from ..utils import wrapper as ml
from ..utils import ml
_logger = logging.getLogger(__name__)

View File

@ -25,7 +25,7 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr
from deepspeed.accelerator import get_accelerator
from ..utils.distributed import init_distributed, distributed_initialized
from ..utils import wrapper as ml
from ..utils import ml
from ..models.lora import freeze_non_lora_weights

View File

@ -17,7 +17,7 @@ from .emb import g2p, qnt
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
from .emb.transcribe import transcribe
from .utils import to_device, set_seed, clamp, wrapper as ml
from .utils import to_device, set_seed, clamp, ml
from .config import cfg, Config
from .models import get_models

View File

@ -1189,8 +1189,13 @@ class AR_NAR(Base):
break
for i, l in enumerate( sequence_list ):
index = (l == audio_stop_token).nonzero()[:, 0].min()
sequence_list[i] = sequence_list[i][:index]
index = (l == audio_stop_token).nonzero()
# kludge for when it doesn't actually hit a stop token; not properly addressed yet since it's only come up in test training so far
try:
index = index[:, 0].min()
sequence_list[i] = sequence_list[i][:index]
except Exception as e:
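# no stop token in this sequence: nonzero() comes back empty, min() raises, and the sequence is left untrimmed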
pass
return sequence_list
@ -1362,8 +1367,6 @@ class AR_NAR(Base):
def example_usage():
cfg.device = "cuda"
cfg.trainer.backend = "local"
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_100
from functools import partial
from einops import repeat
@ -1372,7 +1375,7 @@ def example_usage():
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
from ..data import _load_artifact
from ..engines import Engine, Engines
from ..utils import wrapper as ml
from ..utils import ml
from ..utils import setup_logging
import numpy as np
@ -1403,7 +1406,7 @@ def example_usage():
'n_text_tokens': cfg.model.text_tokens,
'n_audio_tokens': cfg.model.audio_tokens,
'd_model': 1536, # 256, # 1024, # 1536
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_experts': 1 if not cfg.model else cfg.model.experts,
@ -1425,7 +1428,7 @@ def example_usage():
available_tasks = ["tts-nar"]
model = AR_NAR(**kwargs).to(cfg.device)
steps = 250 // batch_size
steps = 100 // batch_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@ -1464,28 +1467,23 @@ def example_usage():
learning_rate = 0.01
optimizer = ml.Apollo
"""
target_params = []
target_modules_list = ["attn", "mlp"]
for module_name, module in model.named_modules():
if not (isinstance(module, torch.nn.Linear)):
continue
if not any(target_key in module_name for target_key in target_modules_list):
continue
target_params.append(module.weight)
param_ids = [id(p) for p in target_params]
regular_params = [p for p in model.parameters() if id(p) not in param_ids]
params = [{'params': regular_params}, {'params': target_params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
"""
params = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(params, lr=learning_rate)
muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
if muon_params is not None:
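# same ndim-based parameter split as in load_engines, just without the frozen_params filtering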
muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
adam_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ] + [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
optimizer = ml.Optimizers([
ml.Muon(**muon_params),
optimizer(adam_params, lr=learning_rate)
])
else:
optimizer = optimizer(params, lr=learning_rate)
if scheduler == "schedulefree":
if isinstance(optimizer, ml.AdamW):

View File

@ -14,7 +14,7 @@ from einops import rearrange
from torch import Tensor, einsum, nn
from torch.utils.checkpoint import checkpoint
from ...utils import wrapper as ml
from ...utils import ml
class AdaLN(nn.Module):
def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):

View File

@ -30,7 +30,7 @@ from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .arch import *
from ..utils import wrapper as ml, clamp
from ..utils import ml, clamp
from ..samplers import *
# yuck, kind of needed

View File

@ -85,22 +85,6 @@ if cfg.optimizations.injects:
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
AVAILABLE_COMPILE_BACKENDS = []
try:
AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
except Exception as e:
pass
if cfg.optimizations.tensorrt:
try:
import torch_tensorrt
AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
except Exception as e:
_logger.warning(f'Error while importing TensorRT: {str(e)}')
pass
if cfg.optimizations.unsloth:
try:
from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
@ -109,20 +93,48 @@ if cfg.optimizations.unsloth:
_logger.warning(f'Error while importing Unsloth: {str(e)}')
pass
class Optimizers(torch.optim.Optimizer):
def __init__(self, opts):
self.opts = opts
def step(self, *args, **kwargs):
for opt in self.opts:
opt.step(*args, **kwargs)
def zero_grad(self, *args, **kwargs):
for opt in self.opts:
opt.zero_grad(*args, **kwargs)
@property
def param_groups(self):
l = []
for opt in self.opts:
l += opt.param_groups
return l
def state_dict(self):
states = []
for i, opt in enumerate( self.opts ):
states.append( opt.state_dict() )
return states
def load_state_dict(self, state_dict):
for opt, state in zip( self.opts, state_dict ):
opt.load_state_dict( state )
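# illustrative usage of the wrapper above (names are placeholders, not repo code):
#   optimizer = Optimizers([ Muon(**muon_kwargs), torch.optim.AdamW(other_params, lr=1e-4) ])
#   optimizer.zero_grad()  # fans out to every wrapped optimizer
#   loss.backward()
#   optimizer.step()       # likewise steps each wrapped optimizer in turn
#   optimizer.state_dict() # a list with one state dict per wrapped optimizer
# note: param_groups concatenates every wrapped optimizer's groups, so schedulers that iterate param_groups keep working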
try:
from .ext.apollo import Apollo
except Exception as e:
_logger.warning(f'Error while importing APOLLO: {str(e)}')
pass
def compile_model(model, backend="auto"):
if not backend or backend == "auto":
backend = AVAILABLE_COMPILE_BACKENDS[0]
if backend not in AVAILABLE_COMPILE_BACKENDS:
return torch.compile(model)
return torch.compile(model, backend=backend)
try:
from muon import Muon as Muon
except Exception as e:
raise e
#_logger.warning(f'Error while importing Muon: {str(e)}')
#pass
# https://github.com/konstmish/prodigy
try:
@ -154,4 +166,29 @@ def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False )
def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
return replace_embedding_old( model, klass, target, verbose )
Embedding.forward = autocast_forward(Embedding.forward)
AVAILABLE_COMPILE_BACKENDS = []
try:
AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
except Exception as e:
pass
def compile_model(model, backend="auto"):
if not backend or backend == "auto":
backend = AVAILABLE_COMPILE_BACKENDS[0]
if backend not in AVAILABLE_COMPILE_BACKENDS:
return torch.compile(model)
return torch.compile(model, backend=backend)
if cfg.optimizations.tensorrt:
try:
import torch_tensorrt
AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
except Exception as e:
_logger.warning(f'Error while importing TensorRT: {str(e)}')
pass

View File

@ -33,7 +33,7 @@ from .distributed import (
from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines
from .utils import to_device, do_gc, truncate_json
from ..utils import wrapper as ml
from ..utils import ml
from ..data import get_phone_symmap # should decouple from this trainer script
_logger = logging.getLogger(__name__)

View File

@ -53,7 +53,7 @@ def md5_hash( x ):
return hashlib.md5(str(x).encode("utf-8")).hexdigest()
# removes entries from a dict if that key is missing from the source
def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True, ignore=["optimizer_params"] ):
is_obj = hasattr( source, "__dict__" )
if parent_is_obj is None:
parent_is_obj = is_obj
@ -65,6 +65,9 @@ def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, retu
keep[k] = dest[k]
else:
missing.append(".".join(path + [k]))
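# skip recursing into free-form dicts like optimizer_params so nested entries (e.g. the muon sub-dict) aren't pruned as missing from the defaults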
if k in ignore:
continue
if recurse and isinstance( v, dict ):
keep[k], m = prune_missing( haystack[k], dest[k], path=path + [k], parent_is_obj=parent_is_obj, return_missing=return_missing )