exposed rolling resp context to the web UI, added passing the language to the inferencing command line

This commit is contained in:
mrq 2023-10-12 23:21:01 -05:00
parent 298fd9a5f9
commit fb467b19ba
4 changed files with 22 additions and 6 deletions

View File

@ -121,12 +121,14 @@ This will export the latest checkpoints, for example, under `./data/ckpt/ar-retn
To synthesize speech, invoke either `python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` (if you exported the models) or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
Some additional flags you can pass are:
* `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language.
* `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it work fine.
* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides clean output, but values upward of `0.6` seem fine too.
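For example, a complete invocation combining these flags might look like the following (the text and paths are placeholders):

`python -m vall_e "The quick brown fox." ./reference.wav ./output.wav yaml=./config.yaml --language en --max-ar-steps 450 --ar-temp 0.95 --nar-temp 0.2` (450 steps caps the output at six seconds)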
And some experimental sampling flags you can use too (your mileage will ***definitely*** vary):
* `--max-ar-context`: Number of `resp` tokens to keep in the context when inferencing. This is akin to a "rolling context" in an effort to curb context limitations, but it currently does not seem fruitful.
* `--min-ar-temp` / `--min-nar-temp`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token (sketched below). Acceptable values lie within `[0.0, (n)ar-temp)`.
+ This simply ports over the [original implementation](https://github.com/kalomaze/koboldcpp/blob/dynamic-temp/llama.cpp#L5132) to perform it.
+ **!**NOTE**!**: This does not seem to resolve any issues with setting the temperature too high or too low. The right values are yet to be found.
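A minimal sketch of what that pathway does, assuming a linear mapping from top-token confidence down to the minimum temperature (the linked implementation differs in its details):

```python
import torch

def dynamic_temperature(logits: torch.Tensor, temp: float, min_temp: float) -> float:
    # probability of the most likely token at the base temperature
    confidence = torch.softmax(logits / temp, dim=-1).max().item()
    # the more confident the model, the closer the effective temperature sits to min_temp
    return min_temp + (temp - min_temp) * (1.0 - confidence)
```

Under this sketch, `--ar-temp 0.95 --min-ar-temp 0.85` would sample a near-certain token at a temperature close to `0.85`.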
@ -146,8 +148,8 @@ And some experimental sampling flags you can use too (your mileage will ***defin
* train and release a ***good*** model.
* clean up the README, and document, document, document onto the wiki.
* extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
* improve throughput:
* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
* improve throughput (despite peaking at 120 it/s):
- properly utilize RetNet's recurrent forward / chunkwise forward passes
- utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding N+1, N+2, N+3 AR tokens
+ this requires a properly trained AR, however.

View File

@ -9,6 +9,7 @@ def main():
parser = argparse.ArgumentParser("VALL-E TTS")
parser.add_argument("text")
parser.add_argument("references", type=path_list)
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None)
@ -44,6 +45,7 @@ def main():
tts.inference(
text=args.text,
references=args.references,
language=args.language,
out_path=args.out_path,
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,

View File

@ -13,7 +13,7 @@ from .utils import to_device
from .config import cfg
from .models import get_models
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, _load_quants, _cleanup_phones
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones
if deepspeed_available:
import deepspeed
@ -127,6 +127,13 @@ class TTS():
phones = [ " " if not p else p for p in content ]
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
def encode_lang( self, language ):
symmap = get_lang_symmap()
id = 0
if language in symmap:
id = symmap[language]
return torch.tensor([ id ])
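# e.g. with a language symmap like { "en": 0, "ja": 1 }, encode_lang("ja") would yield tensor([ 1 ]),
# while a language missing from the symmap falls back to id 0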
def encode_audio( self, paths, trim_length=0.0 ):
# already a tensor, return it
if isinstance( paths, Tensor ):
@ -149,6 +156,7 @@ class TTS():
self,
text,
references,
language="en",
max_ar_steps=6 * 75,
max_ar_context=-1,
max_nar_levels=7,
@ -171,14 +179,16 @@ class TTS():
out_path = f"./data/{cfg.start_time}.wav"
prom = self.encode_audio( references, trim_length=input_prompt_length )
phns = self.encode_text( text )
phns = self.encode_text( text, language=language )
lang = self.encode_lang( language )
prom = to_device(prom, self.device).to(torch.int16)
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
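# like the phoneme ids above, the language id is small enough to fit in a uint8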
lang = to_device(lang, self.device).to(torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
resps_list = self.ar(
text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, max_resp_context=max_ar_context,
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
@ -190,7 +200,7 @@ class TTS():
)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(
text_list=[phns], proms_list=[prom], resps_list=resps_list,
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,

View File

@ -77,6 +77,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*75))
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
@ -200,6 +201,7 @@ with ui:
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")