resnet-classifier/image_classifier/models/__init__.py
2023-08-05 03:48:06 +00:00

19 lines
365 B
Python
Executable File

from .base import Model
def get_model(cfg):
    """Construct a ``Model`` from a single model config and report its size.

    The config object is read for ``name``, ``tokens``, ``len`` and ``dim``;
    the last three are forwarded to ``Model`` as ``n_tokens``, ``n_len`` and
    ``d_model`` respectively.

    Side effects:
        Attaches ``cfg`` to the returned model as ``model._cfg`` and prints
        the number of trainable parameters (those whose ``requires_grad`` is
        true) to stdout.

    Args:
        cfg: Model configuration object (attribute-style access).

    Returns:
        The newly built ``Model`` instance.
    """
    # Read the name up front so a config missing ``name`` fails before any
    # model is constructed.
    label = cfg.name
    built = Model(
        n_tokens=cfg.tokens,
        n_len=cfg.len,
        d_model=cfg.dim,
    )
    # Keep the originating config on the model for later reference.
    built._cfg = cfg

    # Count only parameters that will be trained.
    # NOTE(review): presumably Model is a torch.nn.Module; confirm.
    trainable = 0
    for param in built.parameters():
        if param.requires_grad:
            trainable += param.numel()
    print(f"{label} parameter count: {trainable}")
    return built
def get_models(models):
    """Build every configured model and index the results by name.

    Each config is passed to ``get_model``. The resulting model is stored
    under that config's ``full_name``.

    Args:
        models: Iterable of model config objects, each accepted by
            ``get_model`` and exposing a ``full_name`` attribute.

    Returns:
        A dict mapping ``cfg.full_name`` to the built model. If two configs
        share a ``full_name``, both are built but the later one overwrites
        the earlier entry.
    """
    registry = {}
    for cfg in models:
        # Resolve the key before building, matching the key-then-value
        # evaluation order of a dict comprehension.
        key = cfg.full_name
        registry[key] = get_model(cfg)
    return registry