# vall-e/vall_e/models/__init__.py
# (source view: 28 lines, 608 B, Python)

from .ar_nar import AR_NAR
# 2023-08-02 21:53:35 +00:00
def get_model(cfg, training=True):
    """Construct an ``AR_NAR`` model from a model config and print its size.

    Args:
        cfg: model config object. The attributes read here are ``name``,
            ``tokens``, ``dim``, ``heads``, ``layers``, ``experts``,
            ``dropout`` and ``input_alignment``; the whole object is also
            passed to ``AR_NAR`` as ``config``.
        training: forwarded unchanged to the ``AR_NAR`` constructor.

    Returns:
        The constructed model, with ``cfg`` attached as ``model._cfg``.
    """
    name = cfg.name
    model = AR_NAR(
        n_tokens=cfg.tokens,
        d_model=cfg.dim,
        n_heads=cfg.heads,
        n_layers=cfg.layers,
        n_experts=cfg.experts,
        p_dropout=cfg.dropout,
        l_padding=cfg.input_alignment,
        training=training,
        config=cfg,
    )
    # Attach the originating config to the model instance.
    model._cfg = cfg

    # Report the dtype of the first parameter. Use a default for next() so a
    # model with no parameters does not leak a bare StopIteration from here.
    first_param = next(model.parameters(), None)
    dtype = first_param.dtype if first_param is not None else None
    # Count only parameters that require gradients (i.e. trainable ones).
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name} ({dtype}): {n_trainable} parameters")

    return model
def get_models(models, training=True):
    """Build one model per config, keyed by each config's ``full_name``.

    Args:
        models: iterable of model configs; each one is passed to ``get_model``.
        training: forwarded unchanged to ``get_model`` for every config.

    Returns:
        dict mapping ``cfg.full_name`` to the model built from that config.
        If two configs share a ``full_name``, the later one's model is kept.
    """
    built = {}
    for model_cfg in models:
        # Read the key before building, matching the comprehension's
        # key-then-value evaluation order.
        key = model_cfg.full_name
        built[key] = get_model(model_cfg, training=training)
    return built