diff --git a/README.md b/README.md
index 7a7a3d0..adf9d71 100755
--- a/README.md
+++ b/README.md
@@ -121,12 +121,14 @@ This will export the latest checkpoints, for example, under `./data/ckpt/ar-retn
 To synthesize speech, invoke either (if exported the models): `python -m vall_e --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e yaml=`
 
 Some additional flags you can pass are:
+* `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language.
 * `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
 * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
 * `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it works fine.
 * `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides clean output, but values upward of `0.6` seems fine too.
 
 And some experimental sampling flags you can use too (your mileage will ***definitely*** vary):
+* `--max-ar-context`: number of `resp` tokens to keep in the context when inferencing. This is akin to a "rolling context" in an effort to curb any context limitations, but currently does not seem fruitful.
 * `--min-ar-temp` / `--min-nar-temp`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token. Acceptable values are between `[0.0, (n)ar-temp)`.
   + This simply uplifts the [original implementation](https://github.com/kalomaze/koboldcpp/blob/dynamic-temp/llama.cpp#L5132) to perform it.
   + **!**NOTE**!**: This does not seem to resolve any issues with setting too high/low of a temperature. The right values are yet to be found.
@@ -146,8 +148,8 @@ And some experimental sampling flags you can use too (your mileage will ***defin
 * train and release a ***good*** model.
 * clean up the README, and document, document, document onto the wiki.
-* extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
-* improve throughput:
+* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
+* improve throughput (despite peaking at 120it/s):
   - properly utilize RetNet's recurrent forward / chunkwise forward passes
   - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens
     + this requires a properly trained AR, however.
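For reference, the flags documented above map onto the `TTS.inference()` keywords that `vall_e/__main__.py` wires up in the next diff. A minimal programmatic sketch follows; the `TTS` constructor keywords, the file paths, and the `ar_temp`/`nar_temp` parameter names are assumptions for illustration, since the full signature is truncated in this diff:

```python
# Minimal usage sketch -- constructor keywords, paths, and the ar_temp/nar_temp
# parameter names are assumed here, not confirmed by this diff.
from vall_e.inference import TTS

tts = TTS( config="./data/config.yaml" )  # or point it at exported ./models/ar.pt / ./models/nar.pt

tts.inference(
    text="Hello world.",
    references=[ "./voices/speaker.wav" ],  # acoustic prompt(s) to clone from
    language="en",                           # new flag: selects the phonemizer language
    out_path="./data/out.wav",
    max_ar_steps=6 * 75,                     # ~6 seconds of audio at 75 AR steps per second
    ar_temp=0.95,                            # AR sampling temperature
    nar_temp=0.2,                            # NAR sampling temperature
)
```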
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 9122cb9..27faaef 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -9,6 +9,7 @@ def main():
     parser = argparse.ArgumentParser("VALL-E TTS")
     parser.add_argument("text")
     parser.add_argument("references", type=path_list)
+    parser.add_argument("--language", type=str, default="en")
     parser.add_argument("--out-path", type=Path, default=None)
     parser.add_argument("--yaml", type=Path, default=None)
@@ -44,6 +45,7 @@ def main():
     tts.inference(
         text=args.text,
         references=args.references,
+        language=args.language,
         out_path=args.out_path,
         input_prompt_length=args.input_prompt_length,
         max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 12e5672..a0cc0eb 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -13,7 +13,7 @@ from .utils import to_device
 from .config import cfg
 from .models import get_models
 from .engines import load_engines, deepspeed_available
-from .data import get_phone_symmap, _load_quants, _cleanup_phones
+from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones
 
 if deepspeed_available:
     import deepspeed
@@ -127,6 +127,13 @@ class TTS():
         phones = [ " " if not p else p for p in content ]
         return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
 
+    def encode_lang( self, language ):
+        symmap = get_lang_symmap()
+        id = 0
+        if language in symmap:
+            id = symmap[language]
+        return torch.tensor([ id ])
+
     def encode_audio( self, paths, trim_length=0.0 ):
         # already a tensor, return it
         if isinstance( paths, Tensor ):
@@ -149,6 +156,7 @@ class TTS():
         self,
         text,
         references,
+        language="en",
         max_ar_steps=6 * 75,
         max_ar_context=-1,
         max_nar_levels=7,
@@ -171,14 +179,16 @@ class TTS():
             out_path = f"./data/{cfg.start_time}.wav"
 
         prom = self.encode_audio( references, trim_length=input_prompt_length )
-        phns = self.encode_text( text )
+        phns = self.encode_text( text, language=language )
+        lang = self.encode_lang( language )
 
         prom = to_device(prom, self.device).to(torch.int16)
         phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
+        lang = to_device(lang, self.device).to(torch.uint8)
 
         with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
             resps_list = self.ar(
-                text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, max_resp_context=max_ar_context,
+                text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
                 sampling_temperature=ar_temp, sampling_min_temperature=min_ar_temp,
                 sampling_top_p=top_p, sampling_top_k=top_k,
@@ -190,7 +200,7 @@ class TTS():
             )
             resps_list = [r.unsqueeze(-1) for r in resps_list]
             resps_list = self.nar(
-                text_list=[phns], proms_list=[prom], resps_list=resps_list,
+                text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
                 max_levels=max_nar_levels,
                 sampling_temperature=nar_temp,
                 sampling_min_temperature=min_nar_temp,
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 79f1860..481b11e 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -77,6 +77,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--references", type=str, default=kwargs["reference"])
    parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
     parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
+    parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*75))
     parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
     parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
     parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
@@ -200,6 +201,7 @@ with ui:
             layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
             layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
             layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
+            layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
             with gr.Row():
                 layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
                 layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
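As with `max-seconds`, the new `max-seconds-context` slider above is converted from seconds to `resp` tokens by multiplying by 75 before being handed to `--max-ar-context`. Conceptually, the "rolling context" described in the README amounts to feeding only a trailing window of the generated `resp` tokens back into the AR; a hypothetical sketch of that trimming (illustrative only, not the model's actual forward code):

```python
import torch

def trim_resp_context( resps: torch.Tensor, max_ar_context: int ) -> torch.Tensor:
    """Keep only the trailing `max_ar_context` resp tokens; values <= 0 disable trimming."""
    if max_ar_context <= 0:
        return resps
    return resps[-max_ar_context:]

# e.g. 12 seconds of generated tokens (12 * 75 = 900) trimmed to a 6-second window:
resps = torch.arange( 12 * 75 )
print( trim_resp_context( resps, 6 * 75 ).shape )  # torch.Size([450])
```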