Add model presets
This commit is contained in:
parent
ae029c1d75
commit
3a4d5be18b
@ -37,11 +37,12 @@ class Config(ConfigBase):
|
||||
eval_every: int = 2_000
|
||||
save_ckpt_every: int = 10_000
|
||||
|
||||
model: str = "ar"
|
||||
d_model: int = 512
|
||||
n_heads: int = 8
|
||||
n_layers: int = 12
|
||||
p_dropout: float = 0.1
|
||||
model: str = "ar-quarter"
|
||||
spkr_name_getter: str = "lambda p: p.parts[-2]"
|
||||
|
||||
@cached_property
|
||||
def get_spkr(self):
|
||||
return eval(self.spkr_name_getter)
|
||||
|
||||
@property
|
||||
def ds_cfg(self):
|
||||
|
||||
@ -9,30 +9,13 @@ from .config import cfg
|
||||
from .data import create_train_val_dataloader
|
||||
from .emb import qnt
|
||||
from .utils import setup_logging, to_device, trainer
|
||||
from .vall_e import AR, NAR
|
||||
from .vall_e import get_model
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_engines():
|
||||
if cfg.model.lower() == "ar":
|
||||
model = AR(
|
||||
cfg.num_tokens,
|
||||
cfg.d_model,
|
||||
cfg.n_heads,
|
||||
cfg.n_layers,
|
||||
cfg.p_dropout,
|
||||
)
|
||||
elif cfg.model.lower() == "nar":
|
||||
model = NAR(
|
||||
cfg.num_tokens,
|
||||
cfg.d_model,
|
||||
cfg.n_heads,
|
||||
cfg.n_layers,
|
||||
cfg.p_dropout,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(cfg.model)
|
||||
model = get_model(cfg.model)
|
||||
|
||||
engines = dict(
|
||||
model=trainer.Engine(
|
||||
|
||||
@ -1,2 +1,40 @@
|
||||
from ..config import cfg
|
||||
from .ar import AR
|
||||
from .nar import NAR
|
||||
|
||||
|
||||
def get_model(name):
|
||||
name = name.lower()
|
||||
|
||||
if name.startswith("ar"):
|
||||
Model = AR
|
||||
elif str().startswith("nar"):
|
||||
Model = NAR
|
||||
else:
|
||||
raise ValueError("Model name should start with AR or NAR.")
|
||||
|
||||
if "-quarter" in name:
|
||||
model = Model(
|
||||
cfg.num_tokens,
|
||||
d_model=256,
|
||||
n_heads=4,
|
||||
n_layers=12,
|
||||
)
|
||||
elif "-half" in name:
|
||||
model = Model(
|
||||
cfg.num_tokens,
|
||||
d_model=512,
|
||||
n_heads=8,
|
||||
n_layers=12,
|
||||
)
|
||||
elif "-official" in name:
|
||||
model = Model(
|
||||
cfg.num_tokens,
|
||||
d_model=1024,
|
||||
n_heads=16,
|
||||
n_layers=12,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(name)
|
||||
|
||||
return model
|
||||
|
||||
Loading…
Reference in New Issue
Block a user