From 6634d075764c4bdbbb670b0e1fd14acc5743f845 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 23 Feb 2025 11:22:13 -0600
Subject: [PATCH] added muon optimizer through kludge hacks because it
 necessitates a second optimizer in tandem that seems to only sometimes work
 with deepspeed

---
 vall_e/engines/__init__.py         | 20 +++++--
 vall_e/engines/base.py             |  2 +-
 vall_e/engines/deepspeed.py        |  2 +-
 vall_e/inference.py                |  2 +-
 vall_e/models/ar_nar.py            | 46 ++++++++--------
 vall_e/models/arch/transformer.py  |  2 +-
 vall_e/models/base.py              |  2 +-
 vall_e/utils/{wrapper.py => ml.py} | 87 +++++++++++++++++++++---------
 vall_e/utils/trainer.py            |  2 +-
 vall_e/utils/utils.py              |  5 +-
 10 files changed, 110 insertions(+), 60 deletions(-)
 rename vall_e/utils/{wrapper.py => ml.py} (78%)

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 133b7f1..5056bda 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -11,7 +11,7 @@ elif cfg.trainer.backend == "local":
 	from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine

 from ..models import get_models, get_model
-from ..utils import wrapper as ml
+from ..utils import ml
 from ..utils.io import torch_save, torch_load, pick_path
 from ..models.lora import apply_lora, lora_load_state_dict

@@ -114,7 +114,6 @@ def load_engines(training=True, **model_kwargs):
 			"lr": cfg.hyperparameters.learning_rate,
 		}

-
 		if cfg.hyperparameters.optimizer.lower() == "adamw":
 			params["betas"] = (0.9, 0.96)
 			params["eps"] = 1e-07
@@ -149,8 +148,21 @@ def load_engines(training=True, **model_kwargs):
 		else:
 			raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')

+		muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
 		params.update(cfg.hyperparameters.optimizer_params)
-		optimizer = optimizer_class(**params)
+
+		if muon_params is not None:
+			muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
+
+			params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim < 2 and f'model.{name}' not in model.config.frozen_params ]
+			params["params"] += [ param for name, param in model.named_parameters() if not name.startswith('model.') and name not in model.config.frozen_params ]
+
+			optimizer = ml.Optimizers([
+				ml.Muon(**muon_params),
+				optimizer_class(**params),
+			])
+		else:
+			optimizer = optimizer_class(**params)

 	if cfg.hyperparameters.scheduler.lower() == "schedulefree":
 		if cfg.hyperparameters.optimizer.lower() == "adamw":
@@ -406,7 +418,7 @@ def load_engines(training=True, **model_kwargs):
 			try:
 				engine.wandb = wandb.init(project=key_name, **kwargs)
 				engine.wandb.watch(engine.module)
-			except Exception as e:
+			except Exception as e:
 				engine.wandb = None
 	else:
 		engine.wandb = None
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 880c6aa..9f134c0 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -45,7 +45,7 @@ from typing import Any, Protocol
 from functools import cached_property

 from .base import TrainFeeder
-from ..utils import wrapper as ml
+from ..utils import ml

 _logger = logging.getLogger(__name__)
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index f31d500..de79fd6 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -25,7 +25,7 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr
 from deepspeed.accelerator import get_accelerator

 from ..utils.distributed import init_distributed, distributed_initialized
-from ..utils import wrapper as ml
+from ..utils import ml

 from ..models.lora import freeze_non_lora_weights
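A rough sketch of what the hyperparameters block could look like to enable the tandem optimizer: load_engines() above pops the "muon" sub-dict out of optimizer_params, hands it to ml.Muon for every parameter with ndim >= 2, and leaves the remaining keys (plus the ndim < 2 and non-`model.` parameters) to the base optimizer. The lr/momentum names inside the "muon" block are assumptions about the upstream muon package's constructor, not values defined by this patch.

# illustrative only; "muon" is the key load_engines() pops, the values are made up
cfg.hyperparameters.optimizer = "adamw"
cfg.hyperparameters.optimizer_params = {
	"muon": { "lr": 0.02, "momentum": 0.95 },	# forwarded to ml.Muon(**muon_params) for >= 2D weights
	# any other keys are merged into the base optimizer kwargs as before
}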
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 1f2e3c0..3e63fde 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -17,7 +17,7 @@ from .emb import g2p, qnt
 from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
 from .emb.transcribe import transcribe

-from .utils import to_device, set_seed, clamp, wrapper as ml
+from .utils import to_device, set_seed, clamp, ml

 from .config import cfg, Config
 from .models import get_models
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 8b98aff..0db8af5 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -1189,8 +1189,13 @@ class AR_NAR(Base):
 				break

 		for i, l in enumerate( sequence_list ):
-			index = (l == audio_stop_token).nonzero()[:, 0].min()
-			sequence_list[i] = sequence_list[i][:index]
+			index = (l == audio_stop_token).nonzero()
+			# kludge for when it doesn't actually hit a stop token; I can't be bothered to properly address it right now since it only came up in test training
+			try:
+				index = index[:, 0].min()
+				sequence_list[i] = sequence_list[i][:index]
+			except Exception as e:
+				pass

 		return sequence_list

@@ -1362,8 +1367,6 @@ class AR_NAR(Base):
 def example_usage():
 	cfg.device = "cuda"
 	cfg.trainer.backend = "local"
-	if cfg.audio_backend == "dac":
-		cfg.sample_rate = 44_100

 	from functools import partial
 	from einops import repeat
@@ -1372,7 +1375,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
 	from ..data import _load_artifact
 	from ..engines import Engine, Engines
-	from ..utils import wrapper as ml
+	from ..utils import ml
 	from ..utils import setup_logging

 	import numpy as np
@@ -1403,7 +1406,7 @@ def example_usage():
 		'n_text_tokens': cfg.model.text_tokens,
 		'n_audio_tokens': cfg.model.audio_tokens,

-		'd_model': 1536, # 256, # 1024, # 1536
+		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
 		'n_experts': 1 if not cfg.model else cfg.model.experts,
@@ -1425,7 +1428,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]

 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 250 // batch_size
+	steps = 100 // batch_size

 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -1464,28 +1467,23 @@ def example_usage():
 			learning_rate = 0.01

 		optimizer = ml.Apollo
-
-		"""
-		target_params = []
-		target_modules_list = ["attn", "mlp"]
-		for module_name, module in model.named_modules():
-			if not (isinstance(module, torch.nn.Linear)):
-				continue
-			if not any(target_key in module_name for target_key in target_modules_list):
-				continue
-			target_params.append(module.weight)
-
-		param_ids = [id(p) for p in target_params]
-		regular_params = [p for p in model.parameters() if id(p) not in param_ids]
-		params = [{'params': regular_params}, {'params': target_params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
-		"""
 		params = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
 	else:
 		raise ValueError(f"Unrecognized optimizer: {optimizer}")

 	_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
-
-	optimizer = optimizer(params, lr=learning_rate)
+
+	muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
+	if muon_params is not None:
+		muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
+		adam_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ] + [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
+
+		optimizer = ml.Optimizers([
+			ml.Muon(**muon_params),
+			optimizer(adam_params, lr=learning_rate)
+		])
+	else:
+		optimizer = optimizer(params, lr=learning_rate)

 	if scheduler == "schedulefree":
 		if isinstance(optimizer, ml.AdamW):
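The truncation loop above papers over sequences that never emit a stop token with a bare try/except. A tighter guard is sketched here for reference only (it is not part of the patch); it assumes the same tensor shapes as the loop above:

import torch

def trim_at_stop( seq: torch.Tensor, stop_token: int ) -> torch.Tensor:
	# nonzero() returns an (N, ndim) tensor of hit positions; when no stop token
	# was generated it is empty and .min() would raise, hence the explicit check
	hits = ( seq == stop_token ).nonzero()
	if hits.numel() == 0:
		return seq
	return seq[:hits[:, 0].min()]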
diff --git a/vall_e/models/arch/transformer.py b/vall_e/models/arch/transformer.py
index fe34f88..0e10fa4 100755
--- a/vall_e/models/arch/transformer.py
+++ b/vall_e/models/arch/transformer.py
@@ -14,7 +14,7 @@ from einops import rearrange
 from torch import Tensor, einsum, nn
 from torch.utils.checkpoint import checkpoint

-from ...utils import wrapper as ml
+from ...utils import ml

 class AdaLN(nn.Module):
 	def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 21ccdcf..aa1a5ba 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -30,7 +30,7 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision

 from .arch import *
-from ..utils import wrapper as ml, clamp
+from ..utils import ml, clamp
 from ..samplers import *

 # yuck, kind of needed
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/ml.py
similarity index 78%
rename from vall_e/utils/wrapper.py
rename to vall_e/utils/ml.py
index 5c86ede..a84045c 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/ml.py
@@ -85,22 +85,6 @@ if cfg.optimizations.injects:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD

-AVAILABLE_COMPILE_BACKENDS = []
-
-try:
-	AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
-except Exception as e:
-	pass
-
-
-if cfg.optimizations.tensorrt:
-	try:
-		import torch_tensorrt
-		AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
-	except Exception as e:
-		_logger.warning(f'Error while importing TensorRT: {str(e)}')
-		pass
-
 if cfg.optimizations.unsloth:
 	try:
 		from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
@@ -109,20 +93,48 @@ if cfg.optimizations.unsloth:
 		_logger.warning(f'Error while importing Unsloth: {str(e)}')
 		pass

+class Optimizers(torch.optim.Optimizer):
+	def __init__(self, opts):
+		self.opts = opts
+
+	def step(self, *args, **kwargs):
+		for opt in self.opts:
+			opt.step(*args, **kwargs)
+
+	def zero_grad(self, *args, **kwargs):
+		for opt in self.opts:
+			opt.zero_grad(*args, **kwargs)
+
+	@property
+	def param_groups(self):
+		l = []
+		for opt in self.opts:
+			l += opt.param_groups
+		return l
+
+	def state_dict(self):
+		states = []
+		for i, opt in enumerate( self.opts ):
+			states.append( opt.state_dict() )
+
+		return states
+
+	def load_state_dict(self, state_dict):
+		for opt, state in zip( self.opts, state_dict ):
+			opt.load_state_dict( state )
+
 try:
 	from .ext.apollo import Apollo
 except Exception as e:
 	_logger.warning(f'Error while importing APOLLO: {str(e)}')
 	pass

-def compile_model(model, backend="auto"):
-	if not backend or backend == "auto":
-		backend = AVAILABLE_COMPILE_BACKENDS[0]
-
-	if backend not in AVAILABLE_COMPILE_BACKENDS:
-		return torch.compile(model)
-
-	return torch.compile(model, backend=backend)
+try:
+	from muon import Muon as Muon
+except Exception as e:
+	raise e
+	#_logger.warning(f'Error while importing Muon: {str(e)}')
+	#pass

 # https://github.com/konstmish/prodigy
 try:
@@ -154,4 +166,29 @@ def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False )
 def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
 	return replace_embedding_old( model, klass, target, verbose )

-Embedding.forward = autocast_forward(Embedding.forward)
\ No newline at end of file
+Embedding.forward = autocast_forward(Embedding.forward)
+
+AVAILABLE_COMPILE_BACKENDS = []
+
+try:
+	AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
+except Exception as e:
+	pass
+
+def compile_model(model, backend="auto"):
+	if not backend or backend == "auto":
+		backend = AVAILABLE_COMPILE_BACKENDS[0]
+
+	if backend not in AVAILABLE_COMPILE_BACKENDS:
+		return torch.compile(model)
+
+	return torch.compile(model, backend=backend)
+
+
+if cfg.optimizations.tensorrt:
+	try:
+		import torch_tensorrt
+		AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
+	except Exception as e:
+		_logger.warning(f'Error while importing TensorRT: {str(e)}')
+		pass
\ No newline at end of file
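For reference, a minimal sketch of how the Optimizers wrapper above is meant to be driven, mirroring the split done in load_engines(): matrix-shaped weights go to Muon, everything else stays on AdamW. It assumes the upstream muon package imports cleanly; the Muon keyword arguments and learning rates are illustrative assumptions, not values taken from this patch.

import torch
from torch import nn
from vall_e.utils import ml

net = nn.Sequential( nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1) )
matrix_params = [ p for p in net.parameters() if p.ndim >= 2 ]	# handled by Muon
other_params = [ p for p in net.parameters() if p.ndim < 2 ]	# biases etc. stay on AdamW

opt = ml.Optimizers([
	ml.Muon( matrix_params, lr=0.02, momentum=0.95 ),	# kwargs are assumptions about the muon package
	torch.optim.AdamW( other_params, lr=1.0e-4 ),
])

net( torch.randn(4, 16) ).square().mean().backward()
opt.step()		# steps both optimizers in tandem
opt.zero_grad()	# clears gradients on both
states = opt.state_dict()	# a list of per-optimizer state dicts; restored with load_state_dict()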
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index dd955df..90e0854 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -33,7 +33,7 @@ from .distributed import (
 from ..engines import Engine, Engines, TrainFeeder, default_feeder, load_engines

 from .utils import to_device, do_gc, truncate_json
-from ..utils import wrapper as ml
+from ..utils import ml
 from ..data import get_phone_symmap # should decouple from this trainer script

 _logger = logging.getLogger(__name__)
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 6c82ac2..a93cfe9 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -53,7 +53,7 @@ def md5_hash( x ):
 	return hashlib.md5(str(x).encode("utf-8")).hexdigest()

 # removes entries from a dict if that key is missing from the source
-def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
+def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True, ignore=["optimizer_params"] ):
 	is_obj = hasattr( source, "__dict__" )
 	if parent_is_obj is None:
 		parent_is_obj = is_obj
@@ -65,6 +65,9 @@ def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, retu
 			keep[k] = dest[k]
 		else:
 			missing.append(".".join(path + [k]))
+
+		if k in ignore:
+			continue

 		if recurse and isinstance( v, dict ):
 			keep[k], m = prune_missing( haystack[k], dest[k], path=path + [k], parent_is_obj=parent_is_obj, return_missing=return_missing )
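The new ignore list on prune_missing exists so that a user-supplied optimizer_params block is not recursed into and discarded: the default config has nothing under that key for a nested "muon" dict to match against. A rough sketch of the intent, with illustrative values:

default_cfg = { "optimizer_params": {} }
user_cfg = { "optimizer_params": { "muon": { "lr": 0.02 } } }	# values are made up

kept, missing = prune_missing( default_cfg, user_cfg )
# "optimizer_params" is skipped before the recursion step, so the nested "muon"
# block survives for load_engines() to pop later instead of being pruned away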