vall-e/vall_e/models/__init__.py

42 lines
1001 B
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
def get_model(config, training=True):
    """Instantiate the model described by ``config`` and print a short summary.

    Args:
        config: Model configuration object. This function reads ``name``,
            ``experimental``, ``text_tokens``, ``audio_tokens``, ``dim``,
            ``heads``, ``layers`` and ``dropout`` from it; the
            non-experimental path additionally reads ``experts`` and
            ``input_alignment``. The object itself is also handed to the
            model constructor as ``config``.
        training: Forwarded to ``AR_NAR`` as its ``training`` argument. It is
            not passed to the experimental model.

    Returns:
        The constructed model instance.
    """
    name = config.name

    # The imports live inside the branches, so only the implementation that is
    # actually selected gets imported.
    if not config.experimental:
        from .ar_nar import AR_NAR

        model = AR_NAR(
            n_text_tokens=config.text_tokens,
            n_audio_tokens=config.audio_tokens,
            d_model=config.dim,
            n_heads=config.heads,
            n_layers=config.layers,
            n_experts=config.experts,
            p_dropout=config.dropout,
            l_padding=config.input_alignment,
            training=training,
            config=config,
        )
    else:
        from .experimental import Model as Experimental

        # NOTE(review): `training` is not forwarded on this path -- confirm
        # that the experimental model does not need it.
        model = Experimental(
            n_text_tokens=config.text_tokens,
            n_audio_tokens=config.audio_tokens,
            d_model=config.dim,
            n_layers=config.layers,
            n_heads=config.heads,
            p_dropout=config.dropout,
            config=config,
        )

    # Summary line: model name, dtype of the first parameter, and the number of
    # elements across parameters that have requires_grad set.
    print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

    return model
def get_models(models, training=True):
    """Build one model per configuration, keyed by ``full_name``.

    Args:
        models: Iterable of model configuration objects. Each one must expose
            a ``full_name`` attribute and be acceptable to ``get_model``.
        training: Passed through unchanged to every ``get_model`` call.

    Returns:
        A dict mapping each configuration's ``full_name`` to the model that
        ``get_model`` returned for it. An empty iterable yields an empty dict.
        Configurations that share a ``full_name`` overwrite one another, with
        the last one winning.
    """
    return {model.full_name: get_model(model, training=training) for model in models}