2023-08-02 21:53:35 +00:00
|
|
|
from .ar import AR
|
|
|
|
from .nar import NAR
|
2023-09-06 23:58:35 +00:00
|
|
|
from .ar_nar import AR_NAR
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2024-02-01 03:48:36 +00:00
|
|
|
def get_model(cfg, training=True):
    """Build a model instance described by a model config.

    The model class is chosen from ``cfg.name``:
    ``"ar"`` -> ``AR``, ``"nar"`` -> ``NAR``, ``"ar+nar"`` -> ``AR_NAR``.
    The class is constructed from the config's hyperparameters. The config
    is then attached to the instance as ``model._cfg``, and a one-line
    summary (name, dtype of the first parameter, trainable parameter
    count) is printed.

    Args:
        cfg: model config object. This function reads ``name``, ``tokens``,
            ``dim``, ``heads``, ``layers``, ``experts`` and
            ``input_alignment`` from it and also passes the whole object
            to the constructor as ``config``.
        training: forwarded unchanged to the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``cfg.name`` is not one of the recognized names.
    """
    if cfg.name == "ar":
        Model = AR
    elif cfg.name == "nar":
        Model = NAR
    elif cfg.name == "ar+nar":
        Model = AR_NAR
    else:
        # Raising a plain string is itself a TypeError in Python 3, which
        # would hide this message; raise a real exception instead.
        raise ValueError(f"invalid model name: {cfg.name}")

    name = cfg.name

    model = Model(
        n_tokens=cfg.tokens,
        d_model=cfg.dim,
        n_heads=cfg.heads,
        n_layers=cfg.layers,
        n_experts=cfg.experts,
        l_padding=cfg.input_alignment,
        training=training,
        config=cfg,
    )
    model._cfg = cfg

    # NOTE(review): assumes the model exposes a torch.nn.Module-style
    # parameters() iterator with at least one parameter; next() raises
    # StopIteration on a parameterless model.
    print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

    return model
|
|
|
|
|
2024-02-01 03:48:36 +00:00
|
|
|
def get_models(models, training=True):
    """Build one model per config and key each model by its config's ``full_name``.

    Args:
        models: iterable of model config objects; each must expose
            ``full_name`` and whatever ``get_model`` reads.
        training: forwarded to ``get_model`` for every config.

    Returns:
        A dict mapping each config's ``full_name`` to its built model. Later
        configs sharing a ``full_name`` overwrite earlier ones.
    """
    built = {}
    for model_cfg in models:
        # Read the key before building, matching a dict comprehension's
        # key-then-value evaluation order.
        key = model_cfg.full_name
        built[key] = get_model(model_cfg, training=training)
    return built
|