from .base import Model


def get_model(cfg, training=False):
    """Build a ``Model`` from ``cfg`` and print its trainable parameter count.

    Args:
        cfg: Configuration object; its ``name``, ``tokens``, ``len``, ``dim``
            and ``resnet`` attributes are read here.
        training: Not used by this function.

    Returns:
        The constructed model, with ``cfg`` attached as ``model.config``.
    """
    name = cfg.name

    # Map config fields onto the constructor's keyword arguments.
    hyperparams = dict(
        n_tokens=cfg.tokens,
        n_len=cfg.len,
        d_model=cfg.dim,
        d_resnet=cfg.resnet,
    )
    model = Model(**hyperparams)
    model.config = cfg

    # Only parameters with requires_grad set are included in the count.
    print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    return model


def get_models(models, training=False):
    """Build one model per config, keyed by each config's ``full_name``.

    Args:
        models: Iterable of configuration objects accepted by ``get_model``.
        training: Forwarded unchanged to ``get_model``.

    Returns:
        A dict mapping ``full_name`` to the model built from that config.
    """
    built = {}
    for model_cfg in models:
        # Read the key before building so a missing ``full_name`` fails first.
        key = model_cfg.full_name
        built[key] = get_model(model_cfg, training=training)
    return built