vall-e/vall_e/models/__init__.py
2023-08-02 22:06:39 +00:00

25 lines
530 B
Python
Executable File

from .ar import AR
from .nar import NAR
def get_model(model):
    """Build the network described by the config object *model*.

    Args:
        model: Config object exposing ``name`` (``"ar"`` or ``"nar"``),
            ``tokens``, ``dim``, ``heads`` and ``layers``; these are passed to
            the model constructor as ``n_tokens``, ``d_model``, ``n_heads``
            and ``n_layers``.

    Returns:
        The constructed ``AR`` or ``NAR`` instance. As a side effect, the
        number of trainable parameters (those with ``requires_grad``) is
        printed.

    Raises:
        ValueError: If ``model.name`` is neither ``"ar"`` nor ``"nar"``.
    """
    name = model.name
    if name == "ar":
        Model = AR
    elif name == "nar":
        Model = NAR
    else:
        # Raising a bare string is itself a TypeError in Python 3 and loses
        # the message; raise a real exception that carries it.
        raise ValueError(f"invalid model name: {name}")

    # Keep the config (`model`) and the built network under distinct names.
    instance = Model(
        n_tokens=model.tokens,
        d_model=model.dim,
        n_heads=model.heads,
        n_layers=model.layers,
    )
    # presumably a torch nn.Module (uses .parameters()/.numel()) — confirm
    n_trainable = sum(p.numel() for p in instance.parameters() if p.requires_grad)
    print(f"{name} parameter count: {n_trainable}")
    return instance
def get_models(models):
    """Build every model config in *models*.

    Returns a dict mapping each config's ``full_name`` to the network
    produced by ``get_model`` for that config, in iteration order.
    """
    built = {}
    for spec in models:
        # Read the key before building, matching dict-comprehension order.
        key = spec.full_name
        built[key] = get_model(spec)
    return built