diff --git a/vall_e/demo.py b/vall_e/demo.py
index 0b5f33b..48098f3 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -43,7 +43,7 @@ def main():
 	parser.add_argument("--demo-dir", type=Path, default=None)
 	parser.add_argument("--skip-existing", action="store_true")
 	parser.add_argument("--sample-from-dataset", action="store_true")
-	parser.add_argument("--load-from-dataloader", action="store_true")
+	parser.add_argument("--skip-loading-dataloader", action="store_true")
 	parser.add_argument("--dataset-samples", type=int, default=0)
 	parser.add_argument("--audio-path-root", type=str, default=None)
 	parser.add_argument("--preamble", type=str, default=None)
@@ -89,7 +89,7 @@ def main():
 	if not args.preamble:
 		args.preamble = "\n".join([
 			'Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/.',
-			'I do not consider these to be state of the art, as the model does not follow close to the prompt as I would like for general speakers.',
+			'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
 		])
 
 	# read html template
@@ -115,45 +115,46 @@ def main():
 		"librispeech": args.demo_dir / "librispeech",
 	}
 
+	if (args.demo_dir / "dataset").exists():
+		samples_dirs["dataset"] = args.demo_dir / "dataset"
+
 	# pull from dataset samples
 	if args.sample_from_dataset:
 		cfg.dataset.cache = False
-		samples_dirs["dataset"] = args.demo_dir / "dataset"
 
-		if args.load_from_dataloader:
-			_logger.info("Loading dataloader...")
-			dataloader = create_train_dataloader()
-			_logger.info("Loaded dataloader.")
+		_logger.info("Loading dataloader...")
+		dataloader = create_train_dataloader()
+		_logger.info("Loaded dataloader.")
 
-			num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
+		num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
 
-			length = len( dataloader.dataset )
-			for i in trange( num, desc="Sampling dataset for samples" ):
-				idx = random.randint( 0, length )
-				batch = dataloader.dataset[idx]
+		length = len( dataloader.dataset )
+		for i in trange( num, desc="Sampling dataset for samples" ):
+			idx = random.randint( 0, length )
+			batch = dataloader.dataset[idx]
 
-				dir = args.demo_dir / "dataset" / f'{i}'
+			dir = args.demo_dir / "dataset" / f'{i}'
 
-				(dir / "out").mkdir(parents=True, exist_ok=True)
+			(dir / "out").mkdir(parents=True, exist_ok=True)
 
-				metadata = batch["metadata"]
+			metadata = batch["metadata"]
 
-				text = metadata["text"]
-				language = metadata["language"]
-
-				prompt = dir / "prompt.wav"
-				reference = dir / "reference.wav"
-				out_path = dir / "out" / "ours.wav"
+			text = metadata["text"]
+			language = metadata["language"]
+
+			prompt = dir / "prompt.wav"
+			reference = dir / "reference.wav"
+			out_path = dir / "out" / "ours.wav"
 
-				if args.skip_existing and out_path.exists():
-					continue
+			if args.skip_existing and out_path.exists():
+				continue
 
-				open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
-				open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
+			open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
+			open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
 
-				decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
-				decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
+			decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
+			decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
 
 	for k, sample_dir in samples_dirs.items():
 		if not sample_dir.exists():
@@ -182,23 +183,26 @@ def main():
 		if args.skip_existing and out_path.exists():
 			continue
 
-		tts.inference(
-			text=text,
-			references=[prompt],
-			language=language,
-			out_path=out_path,
-			input_prompt_length=args.input_prompt_length,
-			max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
-			ar_temp=args.ar_temp, nar_temp=args.nar_temp,
-			min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
-			top_p=args.top_p, top_k=args.top_k,
-			repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
-			length_penalty=args.length_penalty,
-			beam_width=args.beam_width,
-			mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
-			seed=args.seed,
-			tqdm=False,
-		)
+		try:
+			tts.inference(
+				text=text,
+				references=[prompt],
+				language=language,
+				out_path=out_path,
+				input_prompt_length=args.input_prompt_length,
+				max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
+				ar_temp=args.ar_temp, nar_temp=args.nar_temp,
+				min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
+				top_p=args.top_p, top_k=args.top_k,
+				repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+				length_penalty=args.length_penalty,
+				beam_width=args.beam_width,
+				mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+				seed=args.seed,
+				tqdm=False,
+			)
+		except Exception as e:
+			print(f'Error while processing {out_path}: {e}')
 
 	# collate entries into HTML
 	samples = [
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 9ca3e05..1a9b55f 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -22,6 +22,7 @@ from .train import train
 from .utils import get_devices, setup_logging
 from .utils.io import json_read, json_stringify
 from .emb.qnt import decode_to_wave
+from .data import get_lang_symmap
 
 tts = None
 
@@ -100,6 +101,9 @@ def load_model( yaml, device, dtype, attention ):
 def get_speakers():
 	return cfg.dataset.training
 
+def get_languages():
+	return get_lang_symmap().keys()
+
 #@gradio_wrapper(inputs=layout["dataset"]["inputs"].keys())
 def load_sample( speaker ):
 	metadata_path = cfg.metadata_dir / f'{speaker}.json'
@@ -158,7 +162,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--text", type=str, default=kwargs["text"])
 	parser.add_argument("--task", type=str, default="tts")
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
-	parser.add_argument("--language", type=str, default="en")
+	parser.add_argument("--language", type=str, default=kwargs["language"])
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
 	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
 	parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
@@ -231,7 +235,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser = argparse.ArgumentParser(allow_abbrev=False)
 	# I'm very sure I can procedurally generate this list
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
-	parser.add_argument("--language", type=str, default="en")
+	parser.add_argument("--language", type=str, default=kwargs["language"])
 	parser.add_argument("--max-ar-steps", type=int, default=0)
 	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
 	parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
@@ -381,6 +385,7 @@ with ui:
 					layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
@@ -419,7 +424,7 @@ with ui:
 					layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 				with gr.Row():
 					layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-
+					layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():
 					layout["inference_stt"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
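
Note: the webui changes above assume that get_lang_symmap() from vall_e.data returns a mapping keyed by language code, so its keys can be reused directly as the dropdown choices. A minimal sketch of that contract follows; the codes and ids are illustrative placeholders, not the project's actual table.

    # Hypothetical stand-in for vall_e.data.get_lang_symmap(); placeholder values,
    # the real mapping lives in the repository's data module.
    def get_lang_symmap():
        return {"en": 0, "ja": 1, "de": 2, "fr": 3}

    def get_languages():
        # Exposed to the webui as dropdown choices, e.g.
        #   gr.Dropdown(choices=get_languages(), label="Language", value="en")
        # Gradio accepts any iterable here, so returning the dict view works.
        return get_lang_symmap().keys()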