From 692d09f9c1e2e6f4ff4cab776eadf24d9d2ae548 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 19 Jul 2024 09:16:37 -0500 Subject: [PATCH] eval/validation fix for SpeechX tasks --- vall_e/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vall_e/train.py b/vall_e/train.py index 4bf06b1..a921ee3 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -99,6 +99,10 @@ def run_eval(engines, eval_name, dl): if task != "tts": filename = f"{filename}_{task}" + # flatten prom + if not isinstance(prom, torch.Tensor): + prom = torch.concat([ p for p in prom if isinstance(p, torch.Tensor) ]) + # to-do, refine the output dir to be sane-er ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav") hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")