reticulating splines

This commit is contained in:
mrq 2023-08-03 21:39:00 -05:00
parent 608c1970eb
commit 0a524f1d59
3 changed files with 7 additions and 8 deletions

View File

@ -5,5 +5,4 @@ if cfg.trainer.backend == "deepspeed":
elif cfg.trainer.backend == "local":
from .base import Engine
from .base import Engines
from .base import TrainFeeder
from .base import Engines, TrainFeeder, default_feeder

View File

@ -9,7 +9,7 @@ class TrainFeeder(Protocol):
) -> None | tuple[Tensor, Stats]:
...
def default_feed(engine, batch):
def default_feeder(engine, batch):
if isinstance(batch, list):
engine( *batch )
elif isinstance(batch, dict):
@ -257,7 +257,7 @@ class Engines(dict[str, Engine]):
stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats
def step(self, batch, feeder: TrainFeeder = default_feed, device=torch.cuda.current_device()):
def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
total_elapsed_time = 0
stats: Any = dict()

View File

@ -11,10 +11,10 @@ def get_model(cfg):
name = cfg.name
model = Model(
n_tokens=model.tokens,
d_model=model.dim,
n_heads=model.heads,
n_layers=model.layers,
n_tokens=cfg.tokens,
d_model=cfg.dim,
n_heads=cfg.heads,
n_layers=cfg.layers,
)
model._cfg = cfg