From ab673e0426f9901c482615909cb4dc429bd26e23 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 3 Aug 2024 21:00:32 -0500 Subject: [PATCH] add cap for NAR-len training, to avoid any weird cases in early training where it'll just mess up and generate long lengths --- vall_e/train.py | 1 + vall_e/utils/trainer.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vall_e/train.py b/vall_e/train.py index e075f7b..7d3a48e 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -112,6 +112,7 @@ def run_eval(engines, eval_name, dl): resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] ) elif "len" in engine.hyper_config.capabilities: len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that + len_list = [ min( l, cfg.evaluation.steps ) for l in len_list ] resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels ) else: if "ar" in engine.hyper_config.capabilities: diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 9c85a85..599ecda 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -127,14 +127,13 @@ def train( engines = load_engines() # validate if there's at least one model to train - if training: - found = False - for name, engine in engines.items(): - if engine.training: - found = True - break - if not found: - raise Exception('Training, but no model loaded set to train...') + found = False + for name, engine in engines.items(): + if engine.training: + found = True + break + if not found: + raise Exception('Training, but no model loaded set to train...') """ if is_local_leader():