From 0a524f1d59d3121e974b1c80cd4a7054e1e18bae Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 3 Aug 2023 21:39:00 -0500 Subject: [PATCH] reticulating splines --- vall_e/engines/__init__.py | 3 +-- vall_e/engines/base.py | 4 ++-- vall_e/models/__init__.py | 8 ++++---- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 3dcf775..cf59d93 100644 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -5,5 +5,4 @@ if cfg.trainer.backend == "deepspeed": elif cfg.trainer.backend == "local": from .base import Engine -from .base import Engines -from .base import TrainFeeder \ No newline at end of file +from .base import Engines, TrainFeeder, default_feeder \ No newline at end of file diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index df24da0..9a559b6 100644 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -9,7 +9,7 @@ class TrainFeeder(Protocol): ) -> None | tuple[Tensor, Stats]: ... -def default_feed(engine, batch): +def default_feeder(engine, batch): if isinstance(batch, list): engine( *batch ) elif isinstance(batch, dict): @@ -257,7 +257,7 @@ class Engines(dict[str, Engine]): stats.update(flatten_dict({ name.split("-")[0]: stat })) return stats - def step(self, batch, feeder: TrainFeeder = default_feed, device=torch.cuda.current_device()): + def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()): total_elapsed_time = 0 stats: Any = dict() diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index 93fe7a5..a1d4b46 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -11,10 +11,10 @@ def get_model(cfg): name = cfg.name model = Model( - n_tokens=model.tokens, - d_model=model.dim, - n_heads=model.heads, - n_layers=model.layers, + n_tokens=cfg.tokens, + d_model=cfg.dim, + n_heads=cfg.heads, + n_layers=cfg.layers, ) model._cfg = cfg