diff --git a/vall_e/config.py b/vall_e/config.py
index 48b1a19..9ade194 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -167,6 +167,7 @@ class Dataset:
 	prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt
 	prompt_similar_top_k: int = 1 # top-k similar candidates to sample from
 	prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from
+	prompt_inject_noise: bool = False # adds noise to the input prompt waveform to try and vary things
 
 	resps_max_samples: int = 1 # number of samples to target for training
 	resps_append_p: float = 1.0 # probability to append another sample to the training target
@@ -176,7 +177,6 @@ class Dataset:
 	reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
 	reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
 	noise_scale: float = 0.25 # scaling noise value
-	noise_inject_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
 
 	retokenize_text: bool = False
 
 	_frames_per_second: int = 0 # allows setting your own hint
diff --git a/vall_e/data.py b/vall_e/data.py
index 86e0ae4..4ba7593 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1010,7 +1010,8 @@ class Dataset(_Dataset):
 		"""
 		prom_length = 0
-		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
+		duration_lo, duration_hi = cfg.dataset.prompt_duration_range
+		trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0
 
 		for _ in range(cfg.dataset.prompt_max_samples):
 			if reference is not None:
@@ -1142,7 +1143,7 @@ class Dataset(_Dataset):
 		if task == "tts":
 			proms = self.sample_prompts(spkr_name, reference=path)
 
-			if cfg.dataset.inject_noise_in_prom:
+			if cfg.dataset.prompt_inject_noise:
 				# sample random noise
 				noise = self.sample_noise()
 				# extend the noise to fill the target audio
@@ -1156,7 +1157,8 @@ class Dataset(_Dataset):
 		elif task == "tts-c":
 			# trim a piece of the output response
 			if naive:
-				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
+				duration_lo, duration_hi = cfg.dataset.prompt_duration_range
+				trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)
 
 				proms = resps[:trim_length, :]
 				resps = resps[trim_length:, :]
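Two things are going on in the config/data change: the option is renamed to match the `prompt_*` naming of its neighbors, and the rename also fixes a latent mismatch, since `config.py` declared `noise_inject_in_prom` while `data.py` checked `cfg.dataset.inject_noise_in_prom`, so the old flag could never actually take effect. The mixing itself is not visible in this hunk; below is a minimal sketch of what the flag plausibly gates, assuming the sampled noise clip is tiled to the prompt's length and blended at `noise_scale`. The helper is hypothetical, not code from this patch.

```python
# Hypothetical sketch only: tile a sampled noise clip to cover the prompt,
# then blend it in at cfg.dataset.noise_scale (0.25 by default above).
import torch

def inject_noise( prom: torch.Tensor, noise: torch.Tensor, scale: float = 0.25 ) -> torch.Tensor:
	# repeat the noise until it is at least as long as the prompt, then trim
	repeats = ( prom.shape[0] + noise.shape[0] - 1 ) // noise.shape[0]
	noise = noise.repeat( repeats, *( [1] * ( noise.dim() - 1 ) ) )[:prom.shape[0]]
	# keep the prompt dominant; add the scaled noise on top
	return prom + noise * scale
```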
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 0c4d8b1..cef1077 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -87,7 +87,7 @@ def main():
 	parser.add_argument("--random-prompts", action="store_true")
 	parser.add_argument("--lora", action="store_true")
-	parser.add_argument("--comparison", action="store_true")
+	parser.add_argument("--comparison", type=str, default=None)
 
 	args = parser.parse_args()
@@ -104,34 +104,47 @@ def main():
 	# comparison kwargs
 	comparison_kwargs = {
-		"enabled": False,
 		"titles": [],
-		"suffix": "_after",
-		"before": {},
-		"after": {}
+		"suffix": "diff",
+		"enabled": {},
+		"disabled": {}
 	}
 
 	if args.lora:
-		comparison_kwargs["enabled"] = True
-		comparison_kwargs["suffix"] = "_lora"
-		comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
-		comparison_kwargs["before"]["use_lora"] = True
-		comparison_kwargs["after"]["use_lora"] = False
-	# to-do: make this user definable
-	elif args.comparison:
-		comparison_kwargs["enabled"] = True
-		comparison_kwargs["suffix"] = "_entropix"
-		comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
-
-		comparison_kwargs["before"]["entropix_sampling"] = True
-		comparison_kwargs["before"]["ar_temp"] = 0.666
-		comparison_kwargs["before"]["top_k"] = 27
-		comparison_kwargs["before"]["top_p"] = 0.9
-		comparison_kwargs["after"]["entropix_sampling"] = False
-		comparison_kwargs["after"]["ar_temp"] = args.ar_temp
-		comparison_kwargs["after"]["top_k"] = args.top_k
-		comparison_kwargs["after"]["top_p"] = args.top_p
+		args.comparison = "lora"
+	# to-do: just make this mappable
+	if args.comparison == "lora":
+		comparison_kwargs["suffix"] = "lora"
+		comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
+
+		comparison_kwargs["disabled"]["use_lora"] = False
+		comparison_kwargs["enabled"]["use_lora"] = True
+	elif args.comparison == "entropix-sampling":
+		comparison_kwargs["suffix"] = "entropix_sampling"
+		comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
+		comparison_kwargs["disabled"]["entropix_sampling"] = False
+		comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
+		comparison_kwargs["disabled"]["top_k"] = args.top_k
+		comparison_kwargs["disabled"]["top_p"] = args.top_p
+		comparison_kwargs["enabled"]["entropix_sampling"] = True
+		comparison_kwargs["enabled"]["ar_temp"] = 0.666
+		comparison_kwargs["enabled"]["top_k"] = 27
+		comparison_kwargs["enabled"]["top_p"] = 0.9
+	elif args.comparison == "ar-temp":
+		comparison_kwargs["suffix"] = "temperature"
+		comparison_kwargs["titles"] = [f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"]
+
+		comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
+		comparison_kwargs["enabled"]["ar_temp"] = 1.0
+	elif args.comparison == "input-prompt-length":
+		comparison_kwargs["suffix"] = "input_prompt_length"
+		comparison_kwargs["titles"] = [f"Prompt Length: {args.input_prompt_length:.2f}s", "Prompt Length: 6.0s"]
+
+		comparison_kwargs["disabled"]["input_prompt_length"] = args.input_prompt_length
+		comparison_kwargs["enabled"]["input_prompt_length"] = 6.0
+	else:
+		raise Exception(f"Unrecognized comparison flag: {args.comparison}")
 
 	# read html template
 	html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
@@ -204,10 +217,9 @@ def main():
 		if not sample_dir.exists():
 			continue
 
-		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
-		sources = [ "ms_valle", "yourtts" ]
-
 		samples = []
+		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
+		sources = [ "ms_valle", "f5" ]
 
 		# generate demo output
 		for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
 			text = open(dir / "prompt.txt", encoding="utf-8").read()
 			prompt = dir / "prompt.wav"
 			reference = dir / "reference.wav"
 			out_path = dir / "out" / "ours.wav"
 			out_path_comparison = dir / "out" / f"ours_{comparison_kwargs['suffix']}.wav"
+			external_sources = [ dir / "out" / f"{source}.wav" for source in sources ]
 
-			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_comparison ] if comparison_kwargs["enabled"] else [])
+			audio_samples = [ prompt, out_path ]
+			if args.comparison:
+				audio_samples += [ out_path_comparison ]
+			audio_samples += [ p for p in external_sources if p.exists() ]
 
 			if not args.random_prompts or k == "librispeech":
-				extra_sources += [ reference ]
+				audio_samples += [ reference ]
 
 			samples.append((
 				text,
-				[ prompt, out_path ] + extra_sources,
+				audio_samples,
 			))
 
-			if args.skip_existing and out_path.exists():
-				continue
-
 			seed = args.seed if args.seed else int(time.time())
 
 			kwargs = dict(
@@ -253,19 +266,20 @@
 			)
 
 			def safe_inference( out_path=out_path ):
+				if args.skip_existing and out_path.exists():
+					return
 				try:
 					tts.inference( out_path=out_path, **kwargs )
 				except Exception as e:
 					print(f'Error while processing {out_path}: {e}')
 
-			if comparison_kwargs["enabled"]:
-				kwargs.update( comparison_kwargs["before"] )
+			if args.comparison:
+				kwargs.update( comparison_kwargs["enabled"] )
 				safe_inference(out_path_comparison)
-				kwargs.update( comparison_kwargs["after"] )
+				kwargs.update( comparison_kwargs["disabled"] )
 
 			safe_inference()
-
 		# collate entries into HTML
 		samples = [
 			f'\n\t\t\t\n\t\t\t\t{text}'+
@@ -280,7 +294,7 @@ def main():
 	# write audio into template
 	html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
 
-	if comparison_kwargs["enabled"]:
+	if args.comparison:
 		before, after = comparison_kwargs["titles"]
 		if args.random_prompts:
 			html = html.replace("Our VALL-E\n\t\t\t\t\tGround Truth", f"Our VALL-E ({before})\n\t\t\t\t\tOur VALL-E ({after})")
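With this change `--comparison` is a string selecting a preset (e.g. `--comparison ar-temp`) instead of a bare toggle hardwired to the Entropix A/B test. The `# to-do: just make this mappable` note points at the obvious next step: collapsing the if/elif chain into a lookup table. One possible shape for that refactor, mirroring the enabled/disabled pairs above; this table is a sketch, not part of the patch:

```python
# Sketch of the "just make this mappable" to-do: one entry per --comparison
# value. Lambdas defer reading args until the preset is actually selected.
COMPARISONS = {
	"lora": lambda args: dict(
		suffix="lora", titles=["No LoRA", "LoRA"],
		disabled={ "use_lora": False }, enabled={ "use_lora": True },
	),
	"ar-temp": lambda args: dict(
		suffix="temperature", titles=[f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"],
		disabled={ "ar_temp": args.ar_temp }, enabled={ "ar_temp": 1.0 },
	),
}

def get_comparison_kwargs( args ):
	if args.comparison not in COMPARISONS:
		raise ValueError( f"Unrecognized comparison flag: {args.comparison}" )
	return COMPARISONS[args.comparison]( args )
```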
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 8ce9285..f1b2de3 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -346,14 +346,15 @@ with ui:
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
 					#layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
-					layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
+					layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.9, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 					layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 				with gr.Row():
 					#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
 					layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-					layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
+					if cfg.experimental:
+						layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():
@@ -394,7 +395,8 @@ with ui:
 					layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 				with gr.Row():
 					layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-					layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
+					if cfg.experimental:
+						layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 					layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():
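Gating the Entropix checkbox behind `cfg.experimental` means the `entropix-sampling` key is simply absent from `layout[...]["inputs"]` when experimental features are off, so anything that iterates or indexes those inputs must tolerate the missing widget. A defensive lookup along these lines would avoid a `KeyError`; the helper is hypothetical and assumes direct dict access elsewhere, it is not code from this patch:

```python
# Hypothetical guard for consumers of the layout dict: the checkbox is never
# registered when cfg.experimental is off, so fall back to a default rather
# than indexing layout[tab]["inputs"][name] directly.
def get_input( layout, tab: str, name: str, default=None ):
	return layout[tab]["inputs"].get( name, default )

entropix_sampling = get_input( layout, "inference_tts", "entropix-sampling", default=False )
```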