diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 3112ff1..e8341a9 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -134,6 +134,8 @@ class AR_NAR(Base): tone_list: list[Tensor] | None = None, len_list: list[Tensor] | None = None, + training: bool | None = None, + max_steps: int = 1000, max_levels: int = 0, max_resp_context: int = -1, @@ -157,8 +159,11 @@ class AR_NAR(Base): n_levels_set = {r.shape[-1] for r in resps_list} n_levels = next(iter(n_levels_set)) + if training is None: + training = n_levels == self.n_resp_levels + # is training - if n_levels == self.n_resp_levels: + if training: # to-do: make this YAML configurable def sample_task(): return "tts" diff --git a/vall_e/train.py b/vall_e/train.py index ab7fbdc..5124880 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -64,6 +64,8 @@ def train_feeder(engine, batch): proms_list=batch["proms"], resps_list=batch["resps"], lang_list=batch["lang"], + + training=True, ) losses = engine.gather_attribute("loss")