eval/validation fix for SpeechX tasks

This commit is contained in:
mrq 2024-07-19 09:16:37 -05:00
parent 28a674e0f1
commit 692d09f9c1

View File

@ -99,6 +99,10 @@ def run_eval(engines, eval_name, dl):
if task != "tts":
filename = f"{filename}_{task}"
# flatten prom
if not isinstance(prom, torch.Tensor):
prom = torch.concat([ p for p in prom if isinstance(p, torch.Tensor) ])
# to-do, refine the output dir to be sane-er
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")