diff --git a/vall_e/data.py b/vall_e/data.py
index 4937c51..0f3143b 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -429,7 +429,7 @@ class Dataset(_Dataset):
 		self.training = training
 		self.dataset_type = "training" if self.training else "validation"
 		self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
-		self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "group"
+		self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
 
 		# to-do: do not do validation if there's nothing in the validation
 		# this just makes it be happy
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 5e7e0c1..d9f31ff 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -86,7 +86,7 @@ class Engine():
 		self._frozen_params.clear()
 
 	@property
-	def training(self):
+	def _training(self):
 		if not hasattr(self, "hyper_config"):
 			return True
 		return self.hyper_config.training
@@ -321,7 +321,7 @@ class Engines(dict[str, Engine]):
 		cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
 
 		for name, engine in self.items():
-			if not engine.training:
+			if not engine._training:
 				continue
 
 			save_dir = cfg.ckpt_dir / name
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 5ed93e4..08258ae 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -77,7 +77,7 @@ class Engine(DeepSpeedEngine):
 		self._frozen_params.clear()
 
 	@property
-	def training(self):
+	def _training(self):
 		return self.hyper_config.training
 
 	@property
diff --git a/vall_e/train.py b/vall_e/train.py
index 535eb3a..2f4fe04 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -57,7 +57,7 @@ def train_feeder(engine, batch):
 	else:
 		engine(
 			text_list=batch["text"],
-			proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
+			proms_list=batch["proms"],
 			resps_list=batch["resps"],
 			lang_list=batch["lang"],
 		)
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index be32ab8..9d4e64b 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -157,9 +157,6 @@ def train(
 
 	# Training loop
 	for batch in _make_infinite_epochs(train_dl):
-		if not engine.training:
-			continue
-
 		if engines.global_step >= cfg.trainer.iterations:
 			break
 