From 039482a48e11bdf91682275c7d50fa9ff8d3919a Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 26 Sep 2024 18:56:57 -0500
Subject: [PATCH] don't do eval on stt because it's so slow and I don't even
 bother doing any metrics against it anyways (to-do: make this a flag)

---
 vall_e/demo.py  | 51 +++++++++++++++++++++++++------------------------
 vall_e/train.py |  2 +-
 2 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/vall_e/demo.py b/vall_e/demo.py
index 71d73f6..0b5f33b 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -43,6 +43,7 @@ def main():
 	parser.add_argument("--demo-dir", type=Path, default=None)
 	parser.add_argument("--skip-existing", action="store_true")
 	parser.add_argument("--sample-from-dataset", action="store_true")
+	parser.add_argument("--load-from-dataloader", action="store_true")
 	parser.add_argument("--dataset-samples", type=int, default=0)
 	parser.add_argument("--audio-path-root", type=str, default=None)
 	parser.add_argument("--preamble", type=str, default=None)
@@ -120,40 +121,40 @@ def main():
 		samples_dirs["dataset"] = args.demo_dir / "dataset"
 
-		"""
-		_logger.info("Loading dataloader...")
-		dataloader = create_train_dataloader()
-		_logger.info("Loaded dataloader.")
+		if args.load_from_dataloader:
+			_logger.info("Loading dataloader...")
+			dataloader = create_train_dataloader()
+			_logger.info("Loaded dataloader.")
 
-		num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
+			num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
 
-		length = len( dataloader.dataset )
-		for i in trange( num, desc="Sampling dataset for samples" ):
-			idx = random.randint( 0, length )
-			batch = dataloader.dataset[idx]
+			length = len( dataloader.dataset )
+			for i in trange( num, desc="Sampling dataset for samples" ):
+				idx = random.randint( 0, length )
+				batch = dataloader.dataset[idx]
 
-			dir = args.demo_dir / "dataset" / f'{i}'
+				dir = args.demo_dir / "dataset" / f'{i}'
 
-			(dir / "out").mkdir(parents=True, exist_ok=True)
+				(dir / "out").mkdir(parents=True, exist_ok=True)
 
-			metadata = batch["metadata"]
+				metadata = batch["metadata"]
 
-			text = metadata["text"]
-			language = metadata["language"]
-
-			prompt = dir / "prompt.wav"
-			reference = dir / "reference.wav"
-			out_path = dir / "out" / "ours.wav"
+				text = metadata["text"]
+				language = metadata["language"]
+
+				prompt = dir / "prompt.wav"
+				reference = dir / "reference.wav"
+				out_path = dir / "out" / "ours.wav"
 
-			if args.skip_existing and out_path.exists():
-				continue
+				if args.skip_existing and out_path.exists():
+					continue
 
-			open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
-			open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
+				open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
+				open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
+
+				decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
+				decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
 
-			decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
-			decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
-		"""
 	for k, sample_dir in samples_dirs.items():
 		if not sample_dir.exists():
 			continue
diff --git a/vall_e/train.py b/vall_e/train.py
index 7c7cd29..58811be 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -115,7 +115,7 @@ def run_eval(engines, eval_name, dl):
 		for i, task in enumerate( batch["task"] ):
 			# easier to just change it to a tts task than drop stt tasks from the batch
 			if task == "stt":
-				has_stt = True
+				# has_stt = True
 				batch["task"][i] = "tts"
 				batch["proms"][i] = batch["resps"][i][:75*3, :]