2024-08-29 18:27:16 +00:00
|
|
|
import logging
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2024-08-27 00:33:51 +00:00
|
|
|
def get_model(config, training=True, **model_kwargs):
    """Instantiate the model implementation selected by ``config``.

    Dispatch order:
      1. ``NAR``          — when ``"len"`` is among ``config.capabilities``.
      2. ``Experimental`` — when ``config.experimental.hf`` is truthy.
      3. ``AR_NAR``       — otherwise (the default).

    Args:
        config: Model configuration; must expose ``name``, ``capabilities``,
            ``text_tokens``, ``audio_tokens``, ``dim``, ``heads``, ``layers``,
            ``experts``, ``dropout``, ``input_alignment`` and
            ``experimental.hf``.
        training: Forwarded to the NAR/AR_NAR constructors so they can set up
            training-specific state.
            NOTE(review): the Experimental branch does not forward ``training``
            — presumably it derives the mode itself; confirm upstream.
        **model_kwargs: Extra keyword arguments passed verbatim to whichever
            model constructor is chosen.

    Returns:
        The constructed model instance.
    """
    name = config.name

    if "len" in config.capabilities:
        # Pure non-autoregressive model.
        # Imported lazily so only the selected backend's dependencies load.
        from .nar import NAR

        model = NAR(
            n_text_tokens=config.text_tokens,
            n_audio_tokens=config.audio_tokens,
            d_model=config.dim,
            n_heads=config.heads,
            n_layers=config.layers,
            n_experts=config.experts,
            p_dropout=config.dropout,
            l_padding=config.input_alignment,
            training=training,
            config=config,
            **model_kwargs,
        )
    elif config.experimental.hf:
        # HuggingFace-backed experimental implementation.
        from .experimental import Model as Experimental

        model = Experimental(
            n_text_tokens=config.text_tokens,
            n_audio_tokens=config.audio_tokens,
            d_model=config.dim,
            n_layers=config.layers,
            n_heads=config.heads,
            p_dropout=config.dropout,
            config=config,
            **model_kwargs,
        )
    else:
        # Default: combined autoregressive + non-autoregressive model.
        from .ar_nar import AR_NAR

        model = AR_NAR(
            n_text_tokens=config.text_tokens,
            n_audio_tokens=config.audio_tokens,
            d_model=config.dim,
            n_heads=config.heads,
            n_layers=config.layers,
            n_experts=config.experts,
            p_dropout=config.dropout,
            l_padding=config.input_alignment,
            training=training,
            config=config,
            **model_kwargs,
        )

    # Lazy %-style args per logging convention; message text matches the
    # original f-string output exactly.
    _logger.info(
        "%s (%s): %s parameters",
        name,
        next(model.parameters()).dtype,
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    )

    return model
|
|
|
|
|
2024-08-27 00:33:51 +00:00
|
|
|
def get_models(models, training=True, **model_kwargs):
    """Construct every model described in *models*.

    Returns a dict mapping each config's ``full_name`` to the model built by
    :func:`get_model`; ``training`` and any extra keyword arguments are
    forwarded to each construction.
    """
    built = {}
    for cfg in models:
        built[cfg.full_name] = get_model(cfg, training=training, **model_kwargs)
    return built
|