diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 5056bda..7b30418 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -148,8 +148,10 @@ def load_engines(training=True, **model_kwargs):
 	else:
 		raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
 
-	muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
-	params.update(cfg.hyperparameters.optimizer_params)
+	muon_params = None
+	if cfg.hyperparameters.optimizer_params is not None:
+		muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
+		params.update(cfg.hyperparameters.optimizer_params)
 
 	if muon_params is not None:
 		muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 0db8af5..9d6f76e 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -1428,7 +1428,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]
 
 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 100 // batch_size
+	steps = 250 // batch_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -1473,7 +1473,7 @@ def example_usage():
 	_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
 
-	muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
+	muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None) if cfg.hyperparameters.optimizer_params is not None else None
 	if muon_params is not None:
 		muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
 		adam_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ] + [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index aa1a5ba..c4ec0c5 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -547,6 +547,7 @@ class Base(nn.Module):
 		self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
 
 		self.stop_token = self.n_audio_tokens # id 1024
+		self.mask_token = self.n_audio_tokens + 1 # id 1025
 		self.causal = "ar" in self.capabilities or "len" in self.capabilities
 		self.version = self.config.version if self.config is not None else 5
 		self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
@@ -715,7 +716,7 @@
 		if self.version >= 7:
 			if monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
-					n_tokens=n_audio_tokens + 1, # masked token
+					n_tokens=n_audio_tokens + 2, # stop + masked token
 					n_levels=self.n_resp_levels,
 					token_dim=d_model,
 				)
@@ -726,7 +727,7 @@
 					token_dim=d_model,
 				)
 				self.resps_emb = AudioEncoder(
-					n_tokens=n_audio_tokens + 1, # masked token
+					n_tokens=n_audio_tokens + 2, # stop + masked token
 					n_levels=self.n_resp_levels,
 					token_dim=d_model,
 				)
@@ -1310,9 +1311,9 @@
 			elif name == "resp":
 				if self.version >= 7:
 					if self.audio_emb is not None:
-						embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
+						embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
 					else:
-						embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
+						embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
 				# if training NAR-len RVQ level 0
 				elif dropout_mask is not None:
 					embedding = self.resps_emb(
diff --git a/vall_e/utils/ml.py b/vall_e/utils/ml.py
index a84045c..fa4502f 100755
--- a/vall_e/utils/ml.py
+++ b/vall_e/utils/ml.py
@@ -132,9 +132,8 @@ except Exception as e:
 try:
 	from muon import Muon as Muon
 except Exception as e:
-	raise e
-	#_logger.warning(f'Error while importing Muon: {str(e)}')
-	#pass
+	_logger.warning(f'Error while importing Muon: {str(e)}')
+	pass
 
 # https://github.com/konstmish/prodigy
 try:
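
Note: a minimal sketch of the None-guard pattern this patch applies around optimizer_params, with the Muon/AdamW split by parameter dimensionality. The optimizer_params, params, and model values below are placeholder stand-ins, not the actual cfg hyperparameters or the AR_NAR model.

# Sketch: guarded pop of the "muon" sub-config, then splitting parameters into
# a Muon group (>= 2D weight matrices) and a fallback group (everything else).
from torch import nn

optimizer_params = {"muon": {"momentum": 0.95}, "lr": 1.0e-4}  # may legitimately be None
params = {"lr": 1.0e-4}
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

muon_params = None
if optimizer_params is not None:
	muon_params = optimizer_params.pop("muon", None)  # no AttributeError when the config is unset
	params.update(optimizer_params)

if muon_params is not None:
	# Muon orthogonalizes 2D+ weight matrices; 1D tensors (biases, norm scales)
	# stay with the element-wise fallback optimizer.
	muon_params["params"] = [p for _, p in model.named_parameters() if p.ndim >= 2]
	adam_params = [p for _, p in model.named_parameters() if p.ndim < 2]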
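
Note: a small sketch of why the encoder tables grow from n_audio_tokens + 1 to n_audio_tokens + 2. With stop_token at id 1024, the new mask_token lands at id 1025, which is out of range for a 1025-row table. A plain nn.Embedding stands in for the repository's AudioEncoder here.

# Sketch: both special ids must fit inside the embedding table, hence the + 2.
import torch
from torch import nn

n_audio_tokens = 1024
stop_token = n_audio_tokens      # id 1024
mask_token = n_audio_tokens + 1  # id 1025

emb = nn.Embedding(n_audio_tokens + 2, 512)        # with + 1, id 1025 would be an index error
out = emb(torch.tensor([stop_token, mask_token]))  # shape: (2, 512)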