def get_model(name):
    """Build an AR or NAR model from a preset name such as ``"ar-quarter"``.

    The name is matched case-insensitively. Its prefix selects the model
    class (``ar`` -> ``AR``, ``nar`` -> ``NAR``) and a size tag found
    anywhere in the name selects the hyperparameters (``-quarter``,
    ``-half`` or ``-official``).

    Args:
        name: Preset name, e.g. ``"ar-quarter"`` or ``"NAR-half"``.

    Returns:
        A newly constructed ``AR`` or ``NAR`` instance, built from
        ``cfg.num_tokens`` and the hyperparameters of the matched preset.

    Raises:
        ValueError: If the name starts with neither ``ar`` nor ``nar``.
        NotImplementedError: If the name contains no known size tag.
    """
    name = name.lower()

    # The prefix must be tested on ``name`` itself; testing an empty
    # ``str()`` made the NAR branch unreachable.
    if name.startswith("ar"):
        autoregressive = True
    elif name.startswith("nar"):
        autoregressive = False
    else:
        raise ValueError("Model name should start with AR or NAR.")

    # Insertion order is the match priority: the first tag contained in the
    # name wins, mirroring an if/elif chain over the same tags.
    presets = {
        "-quarter": dict(d_model=256, n_heads=4, n_layers=12),
        "-half": dict(d_model=512, n_heads=8, n_layers=12),
        "-official": dict(d_model=1024, n_heads=16, n_layers=12),
    }
    for tag, hparams in presets.items():
        if tag in name:
            break
    else:
        raise NotImplementedError(name)

    # Resolve the class only once the name is fully validated, so a bad
    # name fails before any model code is touched.
    Model = AR if autoregressive else NAR
    return Model(cfg.num_tokens, **hparams)