vall-e/vall_e/models/__init__.py

28 lines
546 B
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
from .ar import AR
from .nar import NAR
2023-08-04 01:26:36 +00:00
def get_model(cfg):
    """Instantiate the model selected by ``cfg.name``.

    ``"ar"`` builds an ``AR`` model and ``"nar"`` builds a ``NAR`` model.
    The constructor receives ``cfg.tokens``, ``cfg.dim``, ``cfg.heads`` and
    ``cfg.layers`` as ``n_tokens``/``d_model``/``n_heads``/``n_layers``,
    plus the whole config object as ``config``.

    Args:
        cfg: model config exposing at least ``name``, ``tokens``, ``dim``,
            ``heads`` and ``layers`` attributes.

    Returns:
        The constructed model, with ``cfg`` attached as ``model._cfg``.

    Raises:
        ValueError: if ``cfg.name`` is neither ``"ar"`` nor ``"nar"``.
    """
    name = cfg.name
    if name == "ar":
        Model = AR
    elif name == "nar":
        Model = NAR
    else:
        # Raising a bare string is itself a TypeError in Python 3 and hides
        # this message; raise a real exception instead.
        raise ValueError(f"invalid model name: {name}")

    model = Model(
        n_tokens=cfg.tokens,
        d_model=cfg.dim,
        n_heads=cfg.heads,
        n_layers=cfg.layers,
        config=cfg,
    )
    # Attach the originating config to the instance.
    model._cfg = cfg

    # Count only trainable parameters (those with requires_grad set).
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name} parameter count: {n_params}")
    return model
def get_models(models):
    """Build every model config in *models* via ``get_model``.

    Args:
        models: iterable of model configs, each exposing ``full_name`` and
            whatever ``get_model`` reads.

    Returns:
        dict mapping each config's ``full_name`` to its constructed model;
        a later config with the same ``full_name`` replaces an earlier one.
    """
    built = {}
    for model_cfg in models:
        # Read the key before building, matching dict-comprehension order.
        key = model_cfg.full_name
        built[key] = get_model(model_cfg)
    return built