Add model presets

This commit is contained in:
enhuiz 2023-01-12 19:45:50 +08:00
parent ae029c1d75
commit 3a4d5be18b
3 changed files with 46 additions and 24 deletions

View File

@ -37,11 +37,12 @@ class Config(ConfigBase):
eval_every: int = 2_000
save_ckpt_every: int = 10_000
model: str = "ar"
d_model: int = 512
n_heads: int = 8
n_layers: int = 12
p_dropout: float = 0.1
model: str = "ar-quarter"
spkr_name_getter: str = "lambda p: p.parts[-2]"
@cached_property
def get_spkr(self):
return eval(self.spkr_name_getter)
@property
def ds_cfg(self):

View File

@ -9,30 +9,13 @@ from .config import cfg
from .data import create_train_val_dataloader
from .emb import qnt
from .utils import setup_logging, to_device, trainer
from .vall_e import AR, NAR
from .vall_e import get_model
_logger = logging.getLogger(__name__)
def load_engines():
if cfg.model.lower() == "ar":
model = AR(
cfg.num_tokens,
cfg.d_model,
cfg.n_heads,
cfg.n_layers,
cfg.p_dropout,
)
elif cfg.model.lower() == "nar":
model = NAR(
cfg.num_tokens,
cfg.d_model,
cfg.n_heads,
cfg.n_layers,
cfg.p_dropout,
)
else:
raise NotImplementedError(cfg.model)
model = get_model(cfg.model)
engines = dict(
model=trainer.Engine(

View File

@ -1,2 +1,40 @@
from ..config import cfg
from .ar import AR
from .nar import NAR
def get_model(name):
name = name.lower()
if name.startswith("ar"):
Model = AR
elif str().startswith("nar"):
Model = NAR
else:
raise ValueError("Model name should start with AR or NAR.")
if "-quarter" in name:
model = Model(
cfg.num_tokens,
d_model=256,
n_heads=4,
n_layers=12,
)
elif "-half" in name:
model = Model(
cfg.num_tokens,
d_model=512,
n_heads=8,
n_layers=12,
)
elif "-official" in name:
model = Model(
cfg.num_tokens,
d_model=1024,
n_heads=16,
n_layers=12,
)
else:
raise NotImplementedError(name)
return model