from .base import Model


def get_model(cfg):
    """Build a ``Model`` from *cfg* and report its trainable parameter count.

    Reads ``cfg.name``, ``cfg.tokens``, ``cfg.len`` and ``cfg.dim``. The
    config object is stored on the returned model as ``_cfg``.

    Args:
        cfg: Configuration object exposing ``name``, ``tokens``, ``len``
            and ``dim`` attributes.

    Returns:
        The newly constructed ``Model`` instance.
    """
    name = cfg.name

    # Keyword arguments forwarded to the Model constructor.
    hparams = {
        "n_tokens": cfg.tokens,
        "n_len": cfg.len,
        "d_model": cfg.dim,
    }
    net = Model(**hparams)

    # Keep the originating config reachable from the model itself.
    net._cfg = cfg

    # Only parameters flagged ``requires_grad`` contribute to the total.
    # NOTE(review): presumably Model is a torch.nn.Module — confirm.
    n_trainable = sum(
        param.numel() for param in net.parameters() if param.requires_grad
    )
    print(f"{name} parameter count: {n_trainable}")

    return net


def get_models(models):
    """Build one model per config, keyed by each config's ``full_name``.

    Args:
        models: Iterable of configuration objects accepted by ``get_model``;
            each must also expose a ``full_name`` attribute.

    Returns:
        A dict mapping ``full_name`` to the model built from that config.
        When two configs share a ``full_name`` the later one wins.
    """
    built = {}
    for cfg in models:
        # Read the key before building, matching dict-comprehension order.
        key = cfg.full_name
        built[key] = get_model(cfg)
    return built