exposed the rolling resp context to the web UI; added passing the language to the inferencing command line
commit fb467b19ba
parent 298fd9a5f9
README.md
@@ -121,12 +121,14 @@ This will export the latest checkpoints, for example, under `./data/ckpt/ar-retn
 To synthesize speech, invoke either (if you exported the models): `python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
 
 Some additional flags you can pass are:
+* `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language.
 * `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
 * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
 * `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it work fine.
 * `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, `0.2` provides clean output, but values upward of `0.6` seem fine too.
 
 And some experimental sampling flags you can use too (your mileage will ***definitely*** vary), with a sketch of the dynamic-temperature idea after this hunk:
 * `--max-ar-context`: number of `resp` tokens to keep in the context when inferencing. This is akin to a "rolling context" in an effort to curb context limitations, but it currently does not seem fruitful.
 * `--min-ar-temp` / `--min-nar-temp`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token. Acceptable values are between `[0.0, (n)ar-temp)`.
   + This simply uplifts the [original implementation](https://github.com/kalomaze/koboldcpp/blob/dynamic-temp/llama.cpp#L5132).
   + **Note**: this does not seem to resolve issues with setting too high/low of a temperature; the right values are yet to be found.
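For illustration, here is a minimal sketch of that dynamic-temperature pathway: scale the temperature between the floor and the ceiling by the confidence of the best token. This is a hedged illustration of the idea, not this repo's actual sampler; the function name and the 1-D logits shape are assumptions.

```python
import torch

def sample_dynamic_temperature(logits: torch.Tensor, temp: float = 0.95, min_temp: float = 0.0) -> int:
	# Confidence of the best token: near 1.0 when the model is sure of itself.
	confidence = torch.softmax(logits.float(), dim=-1).max().item()
	# Confident predictions sample cooler (toward min_temp), uncertain ones hotter (toward temp).
	t = max(min_temp + (temp - min_temp) * (1.0 - confidence), 1e-5)
	return torch.multinomial(torch.softmax(logits.float() / t, dim=-1), 1).item()

token = sample_dynamic_temperature(torch.randn(1024))  # e.g. sampling over a 1024-entry codebook
```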
@@ -146,8 +148,8 @@ And some experimental sampling flags you can use too (your mileage will ***definitely*** vary):
 * train and release a ***good*** model.
 * clean up the README, and document, document, document onto the wiki.
-* extend to multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
-* improve throughput:
+* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
+* improve throughput (despite peaking at 120 it/s):
   - properly utilize RetNet's recurrent forward / chunkwise forward passes
   - utilize an approach similar to [FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa/) with additional heads for decoding the N+1, N+2, and N+3 AR tokens
     + this requires a properly trained AR, however.
vall_e/__main__.py
@@ -9,6 +9,7 @@ def main():
 	parser = argparse.ArgumentParser("VALL-E TTS")
 	parser.add_argument("text")
 	parser.add_argument("references", type=path_list)
+	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--out-path", type=Path, default=None)
 
 	parser.add_argument("--yaml", type=Path, default=None)
@@ -44,6 +45,7 @@ def main():
 	tts.inference(
 		text=args.text,
 		references=args.references,
+		language=args.language,
 		out_path=args.out_path,
 		input_prompt_length=args.input_prompt_length,
 		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
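With the new flag wired through `__main__.py`, an invocation might look like the following; the text, reference path, output path, and YAML location are illustrative, not from this commit:

```bash
python -m vall_e "Bonjour tout le monde." ./reference.wav ./out.wav --yaml ./data/config.yaml --language "fr"
```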
vall_e/inference.py
@@ -13,7 +13,7 @@ from .utils import to_device
 from .config import cfg
 from .models import get_models
 from .engines import load_engines, deepspeed_available
-from .data import get_phone_symmap, _load_quants, _cleanup_phones
+from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones
 
 if deepspeed_available:
 	import deepspeed
@@ -127,6 +127,13 @@ class TTS():
 		phones = [ " " if not p else p for p in content ]
 		return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
 
+	def encode_lang( self, language ):
+		symmap = get_lang_symmap()
+		id = 0
+		if language in symmap:
+			id = symmap[language]
+		return torch.tensor([ id ])
+
 	def encode_audio( self, paths, trim_length=0.0 ):
 		# already a tensor, return it
 		if isinstance( paths, Tensor ):
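A note on the new `encode_lang`: a language missing from the symmap silently falls back to id `0`. A self-contained sketch of that behavior, with a hypothetical mapping standing in for `get_lang_symmap()`:

```python
import torch

symmap = { "en": 0, "ja": 1 }  # hypothetical; the real mapping comes from get_lang_symmap()

def encode_lang(language: str) -> torch.Tensor:
	# Mirrors the method above: known languages get their id, anything else defaults to 0.
	return torch.tensor([ symmap.get(language, 0) ])

assert encode_lang("ja").item() == 1
assert encode_lang("xx").item() == 0  # an unknown language quietly maps to the default id
```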
@@ -149,6 +156,7 @@ class TTS():
 		self,
 		text,
 		references,
+		language="en",
 		max_ar_steps=6 * 75,
 		max_ar_context=-1,
 		max_nar_levels=7,
@@ -171,14 +179,16 @@ class TTS():
 			out_path = f"./data/{cfg.start_time}.wav"
 
 		prom = self.encode_audio( references, trim_length=input_prompt_length )
-		phns = self.encode_text( text )
+		phns = self.encode_text( text, language=language )
+		lang = self.encode_lang( language )
 
 		prom = to_device(prom, self.device).to(torch.int16)
 		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
+		lang = to_device(lang, self.device).to(torch.uint8)
 
 		with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
 			resps_list = self.ar(
-				text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, max_resp_context=max_ar_context,
+				text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
 				sampling_temperature=ar_temp,
 				sampling_min_temperature=min_ar_temp,
 				sampling_top_p=top_p, sampling_top_k=top_k,
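The casts above pick the smallest integer dtype that can hold the ids, which is why `phns` checks the symmap size and `lang` is always `uint8`. A standalone sketch of that choice:

```python
import torch

def id_dtype(vocab_size: int) -> torch.dtype:
	# Token ids fit in uint8 only while the vocabulary stays under 256 entries.
	return torch.uint8 if vocab_size < 256 else torch.int16

phns = torch.tensor([1, 5, 9, 2], dtype=id_dtype(180))  # a small phoneme vocab fits in uint8
lang = torch.tensor([0], dtype=torch.uint8)             # a handful of language ids always fits
```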
@@ -190,7 +200,7 @@ class TTS():
 			)
 			resps_list = [r.unsqueeze(-1) for r in resps_list]
 			resps_list = self.nar(
-				text_list=[phns], proms_list=[prom], resps_list=resps_list,
+				text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
 				max_levels=max_nar_levels,
 				sampling_temperature=nar_temp,
 				sampling_min_temperature=min_nar_temp,
vall_e/webui.py
@@ -77,6 +77,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
|
||||
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*75))
|
||||
parser.add_argument("--max-ar-context", type=int, default=int(kwargs["max-seconds-context"]*75))
|
||||
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
|
||||
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
|
||||
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
|
||||
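The defaults above convert the UI's seconds into token counts, using the same 75-steps-per-second figure the README quotes. As a tiny sketch:

```python
def seconds_to_ar_tokens(seconds: float) -> int:
	# One second of audio corresponds to 75 AR steps.
	return int(seconds * 75)

assert seconds_to_ar_tokens(6) == 450   # the "Maximum Seconds" default
assert seconds_to_ar_tokens(0.0) == 0   # a zeroed slider disables the limit
```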
@@ -200,6 +201,7 @@ with ui:
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
||||
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||
layout["inference"]["inputs"]["max-seconds-context"] = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length", info="Amount of generated audio to keep in the context during inference, in seconds. Set 0 to disable.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
||||
|
|
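For context on how a slider like the new Context Length control feeds inference, here is a minimal, hypothetical gradio demo (not part of this commit) that echoes the seconds-to-tokens conversion:

```python
import gradio as gr

with gr.Blocks() as demo:
	# Seconds in, tokens out: mirrors how "max-seconds-context" becomes --max-ar-context.
	seconds = gr.Slider(value=0.0, minimum=0.0, maximum=12.0, step=0.05, label="Context Length")
	tokens = gr.Textbox(label="resp tokens kept")
	seconds.change(lambda s: str(int(s * 75)), inputs=seconds, outputs=tokens)

# demo.launch()  # uncomment to serve locally
```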