diff --git a/vall_e/demo.py b/vall_e/demo.py
index 0b5f33b..48098f3 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -43,7 +43,7 @@ def main():
parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true")
parser.add_argument("--sample-from-dataset", action="store_true")
- parser.add_argument("--load-from-dataloader", action="store_true")
+ parser.add_argument("--skip-loading-dataloader", action="store_true")
parser.add_argument("--dataset-samples", type=int, default=0)
parser.add_argument("--audio-path-root", type=str, default=None)
parser.add_argument("--preamble", type=str, default=None)
@@ -89,7 +89,7 @@ def main():
if not args.preamble:
args.preamble = "
".join([
'Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/.',
- 'I do not consider these to be state of the art, as the model does not follow close to the prompt as I would like for general speakers.',
+ 'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it more strongly than others do.',
])
# read html template
@@ -115,45 +115,46 @@ def main():
"librispeech": args.demo_dir / "librispeech",
}
+ if (args.demo_dir / "dataset").exists():
+ samples_dirs["dataset"] = args.demo_dir / "dataset"
+
# pull from dataset samples
if args.sample_from_dataset:
cfg.dataset.cache = False
-
samples_dirs["dataset"] = args.demo_dir / "dataset"
- if args.load_from_dataloader:
- _logger.info("Loading dataloader...")
- dataloader = create_train_dataloader()
- _logger.info("Loaded dataloader.")
+ _logger.info("Loading dataloader...")
+ dataloader = create_train_dataloader()
+ _logger.info("Loaded dataloader.")
- num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
+ num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size
- length = len( dataloader.dataset )
- for i in trange( num, desc="Sampling dataset for samples" ):
- idx = random.randint( 0, length )
- batch = dataloader.dataset[idx]
+ length = len( dataloader.dataset )
+ for i in trange( num, desc="Sampling dataset for samples" ):
+ idx = random.randint( 0, length - 1 )
+ batch = dataloader.dataset[idx]
- dir = args.demo_dir / "dataset" / f'{i}'
+ dir = args.demo_dir / "dataset" / f'{i}'
- (dir / "out").mkdir(parents=True, exist_ok=True)
+ (dir / "out").mkdir(parents=True, exist_ok=True)
- metadata = batch["metadata"]
+ metadata = batch["metadata"]
- text = metadata["text"]
- language = metadata["language"]
-
- prompt = dir / "prompt.wav"
- reference = dir / "reference.wav"
- out_path = dir / "out" / "ours.wav"
+ text = metadata["text"]
+ language = metadata["language"]
+
+ prompt = dir / "prompt.wav"
+ reference = dir / "reference.wav"
+ out_path = dir / "out" / "ours.wav"
- if args.skip_existing and out_path.exists():
- continue
+ if args.skip_existing and out_path.exists():
+ continue
- open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
- open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
+ open( dir / "prompt.txt", "w", encoding="utf-8" ).write( text )
+ open( dir / "language.txt", "w", encoding="utf-8" ).write( language )
- decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
- decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
+ decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
+ decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
for k, sample_dir in samples_dirs.items():
if not sample_dir.exists():
@@ -182,23 +183,26 @@ def main():
if args.skip_existing and out_path.exists():
continue
- tts.inference(
- text=text,
- references=[prompt],
- language=language,
- out_path=out_path,
- input_prompt_length=args.input_prompt_length,
- max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
- ar_temp=args.ar_temp, nar_temp=args.nar_temp,
- min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
- top_p=args.top_p, top_k=args.top_k,
- repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
- length_penalty=args.length_penalty,
- beam_width=args.beam_width,
- mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
- seed=args.seed,
- tqdm=False,
- )
+ try:
+ tts.inference(
+ text=text,
+ references=[prompt],
+ language=language,
+ out_path=out_path,
+ input_prompt_length=args.input_prompt_length,
+ max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
+ ar_temp=args.ar_temp, nar_temp=args.nar_temp,
+ min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
+ top_p=args.top_p, top_k=args.top_k,
+ repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
+ length_penalty=args.length_penalty,
+ beam_width=args.beam_width,
+ mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
+ seed=args.seed,
+ tqdm=False,
+ )
+ except Exception as e:
+ _logger.error(f'Error while processing {out_path}: {e}')
# collate entries into HTML
samples = [
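
With the --load-from-dataloader gate removed above, the dataset-sampling path always builds the train dataloader when --sample-from-dataset is set. Below is a minimal sketch of what that loop now does, assuming the batch/metadata keys shown in the diff; the wrapper function, its arguments, and the decode_to_file import path are illustrative (you would pass it the result of create_train_dataloader()), not the actual module layout.

import random
from pathlib import Path

from tqdm import trange
# assumed import path, mirroring webui.py's `from .emb.qnt import decode_to_wave`
from vall_e.emb.qnt import decode_to_file

def export_dataset_samples(dataloader, demo_dir: Path, num: int, skip_existing: bool = False, device: str = "cuda"):
    length = len(dataloader.dataset)
    for i in trange(num, desc="Sampling dataset for samples"):
        # randrange stays inside [0, length), avoiding the off-by-one of randint(0, length)
        batch = dataloader.dataset[random.randrange(length)]

        sample_dir = demo_dir / "dataset" / f"{i}"
        (sample_dir / "out").mkdir(parents=True, exist_ok=True)

        out_path = sample_dir / "out" / "ours.wav"
        if skip_existing and out_path.exists():
            continue

        metadata = batch["metadata"]
        (sample_dir / "prompt.txt").write_text(metadata["text"], encoding="utf-8")
        (sample_dir / "language.txt").write_text(metadata["language"], encoding="utf-8")

        # decode the quantized prompt/response codes back into audio on disk
        decode_to_file(batch["proms"].to(device), sample_dir / "prompt.wav", device=device)
        decode_to_file(batch["resps"].to(device), sample_dir / "reference.wav", device=device)

The try/except added around tts.inference() earlier in this hunk serves the same goal: one failed sample is reported and skipped instead of aborting the whole demo run.
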
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 9ca3e05..1a9b55f 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -22,6 +22,7 @@ from .train import train
from .utils import get_devices, setup_logging
from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave
+from .data import get_lang_symmap
tts = None
@@ -100,6 +101,9 @@ def load_model( yaml, device, dtype, attention ):
def get_speakers():
return cfg.dataset.training
+def get_languages():
+ return get_lang_symmap().keys()
+
#@gradio_wrapper(inputs=layout["dataset"]["inputs"].keys())
def load_sample( speaker ):
metadata_path = cfg.metadata_dir / f'{speaker}.json'
@@ -158,7 +162,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--references", type=str, default=kwargs["reference"])
- parser.add_argument("--language", type=str, default="en")
+ parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
@@ -231,7 +235,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser = argparse.ArgumentParser(allow_abbrev=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--references", type=str, default=kwargs["reference"])
- parser.add_argument("--language", type=str, default="en")
+ parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--max-ar-steps", type=int, default=0)
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
@@ -381,6 +385,7 @@ with ui:
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+ layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
@@ -419,7 +424,7 @@ with ui:
layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
with gr.Row():
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-
+ layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():
layout["inference_stt"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")