diff --git a/vall_e/train.py b/vall_e/train.py index 4bf06b1..a921ee3 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -99,6 +99,10 @@ def run_eval(engines, eval_name, dl): if task != "tts": filename = f"{filename}_{task}" + # flatten prom + if not isinstance(prom, torch.Tensor): + prom = torch.concat([ p for p in prom if isinstance(p, torch.Tensor) ]) + # to-do, refine the output dir to be sane-er ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav") hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")