reticulating splines
This commit is contained in:
parent
608c1970eb
commit
0a524f1d59
|
@ -5,5 +5,4 @@ if cfg.trainer.backend == "deepspeed":
|
|||
elif cfg.trainer.backend == "local":
|
||||
from .base import Engine
|
||||
|
||||
from .base import Engines
|
||||
from .base import TrainFeeder
|
||||
from .base import Engines, TrainFeeder, default_feeder
|
|
@ -9,7 +9,7 @@ class TrainFeeder(Protocol):
|
|||
) -> None | tuple[Tensor, Stats]:
|
||||
...
|
||||
|
||||
def default_feed(engine, batch):
|
||||
def default_feeder(engine, batch):
|
||||
if isinstance(batch, list):
|
||||
engine( *batch )
|
||||
elif isinstance(batch, dict):
|
||||
|
@ -257,7 +257,7 @@ class Engines(dict[str, Engine]):
|
|||
stats.update(flatten_dict({ name.split("-")[0]: stat }))
|
||||
return stats
|
||||
|
||||
def step(self, batch, feeder: TrainFeeder = default_feed, device=torch.cuda.current_device()):
|
||||
def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
|
||||
total_elapsed_time = 0
|
||||
|
||||
stats: Any = dict()
|
||||
|
|
|
@ -11,10 +11,10 @@ def get_model(cfg):
|
|||
name = cfg.name
|
||||
|
||||
model = Model(
|
||||
n_tokens=model.tokens,
|
||||
d_model=model.dim,
|
||||
n_heads=model.heads,
|
||||
n_layers=model.layers,
|
||||
n_tokens=cfg.tokens,
|
||||
d_model=cfg.dim,
|
||||
n_heads=cfg.heads,
|
||||
n_layers=cfg.layers,
|
||||
)
|
||||
model._cfg = cfg
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user