25 lines
530 B
Python
Executable File
25 lines
530 B
Python
Executable File
from .ar import AR
|
|
from .nar import NAR
|
|
|
|
def get_model(model):
    """Instantiate the network described by a model config entry.

    Args:
        model: Config object exposing ``name`` ("ar" or "nar") plus the
            ``tokens``, ``dim``, ``heads`` and ``layers`` attributes, which
            are forwarded to the constructor as ``n_tokens``, ``d_model``,
            ``n_heads`` and ``n_layers`` respectively.

    Returns:
        The constructed ``AR`` or ``NAR`` instance.

    Raises:
        ValueError: If ``model.name`` is neither "ar" nor "nar".
    """
    name = model.name

    if name == "ar":
        Model = AR
    elif name == "nar":
        Model = NAR
    else:
        # Raising a bare string is itself a TypeError in Python 3; use a real
        # exception type so the caller sees the intended message.
        raise ValueError(f"invalid model name: {name}")

    # The config object is no longer needed after this point, so the name
    # `model` is rebound to the constructed instance (as the original did).
    model = Model(
        n_tokens=model.tokens,
        d_model=model.dim,
        n_heads=model.heads,
        n_layers=model.layers,
    )

    # Report only parameters flagged as requiring gradients.
    # NOTE(review): assumes the instance exposes a torch-style `parameters()`
    # iterator whose items have `numel()` and `requires_grad` — confirm.
    print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    return model
|
|
|
|
def get_models(models):
    """Build every configured model and index the results by ``full_name``.

    Args:
        models: Iterable of config entries, each exposing ``full_name`` and
            whatever attributes ``get_model`` reads.

    Returns:
        A dict mapping each entry's ``full_name`` to the object returned by
        ``get_model`` for that entry. Entries are processed in iteration
        order, so a later entry with a repeated ``full_name`` replaces the
        earlier one.
    """
    built = {}
    for entry in models:
        # Read the key before building, matching the evaluation order of the
        # equivalent dict comprehension.
        key = entry.full_name
        built[key] = get_model(entry)
    return built
|