20 lines
441 B
Python
20 lines
441 B
Python
from .base import Model
|
|
|
|
def get_model(cfg, training=False):
    """Construct a ``Model`` from a config object and report its size.

    Args:
        cfg: Config object. Its ``name``, ``tokens``, ``len``, ``dim`` and
            ``resnet`` attributes are read here.
        training: Accepted but not used by this function.
            NOTE(review): possibly meant to toggle train/eval mode on the
            model — confirm with callers before relying on it.

    Returns:
        The newly built model, with ``cfg`` attached as ``model.config``.
    """
    label = cfg.name

    model = Model(
        n_tokens=cfg.tokens,
        n_len=cfg.len,
        d_model=cfg.dim,
        d_resnet=cfg.resnet,
    )
    # Keep the originating config on the model so downstream code can
    # recover how it was built.
    model.config = cfg

    # Only parameters with requires_grad=True are counted.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{label} parameter count: {trainable}")
    return model
|
|
|
|
def get_models(models, training=False):
    """Build one model per config, keyed by each config's ``full_name``.

    Args:
        models: Iterable of config objects (the items are configs, not
            models). Each must expose ``full_name`` plus whatever
            ``get_model`` reads.
        training: Forwarded unchanged to ``get_model``.

    Returns:
        dict mapping each config's ``full_name`` to the model built from it.
        If two configs share a ``full_name``, the later one wins.
    """
    registry = {}
    for cfg in models:
        # Read the key before building, matching dict-comprehension
        # evaluation order (key first, then value).
        key = cfg.full_name
        registry[key] = get_model(cfg, training=training)
    return registry
|