added an option to use raw text rather than IPA phonemes (requires a model trained on raw text)

mrq 2025-01-06 00:10:43 -06:00
parent 3ab11bdc7b
commit 9fa87c417a
4 changed files with 29 additions and 9 deletions

View File

@@ -129,13 +129,15 @@ Other solutions will rely on conditioning latents or extracted features as the i
Classifiers are the final output head / projection layer that processes the last hidden states of a model into a probability distribution for each token.
Out of paranoia, each head is split for each macro-task (RVQ level, `stt`, and `len`), even though the core half of the model's training was with a single output head.
* It also avoids needing tricks like setting the logits of unwanted tokens to `-inf` (see the sketch below).
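As a rough sketch (not the actual implementation; the head names, vocabulary sizes, and hidden width below are assumptions), the split amounts to one projection per macro-task, with the current task selecting which head's logits to use:

```python
import torch
from torch import nn

d_model = 1024  # hidden width, assumed

# one output head per macro-task; names and vocabulary sizes are illustrative
heads = nn.ModuleDict({
    **{ f"resp|{level}": nn.Linear(d_model, 1024 + 1) for level in range(8) },  # one head per RVQ level
    "stt": nn.Linear(d_model, 256),  # text/phoneme vocabulary
    "len": nn.Linear(d_model, 11),   # digit + stop tokens for duration prediction
})

hidden_states = torch.randn(1, 32, d_model)  # last hidden states from the model
logits = heads["len"](hidden_states)         # only the current task's head is consulted,
                                             # so no other task's tokens need masking to `-inf`
```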
### Text Embeddings
The input text phonemes (or output text for STT) are passed through an embedding head (`text`), similar to how a normal text LLM would embed its tokens. Nothing fancy is required.
Technically, due to how the audio embeddings are implemented, it's possible to offer "language-specific" text embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
* Such an implementation *could* in fact inference from normal text rather than IPA phonemes, as language-specific pure text embeddings can be trained.
* Such an implementation can instead inference from normal text rather than IPA phonemes, as language-specific pure text embeddings can be trained.
* This is because some arbitrary first `n` layers of the model *might* instead handle encoding the input prompt embeddings. It's easy to take an existing model and train it on raw text tokens alongside the IPA phonemes as an input.
These embeddings *could* instead be added on top of the input prompt embedding rather than serving as additional tasks (similar to injecting position embeddings), but additional experimentation is required to see whether the model can work under this and/or benefits from it.
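As a minimal sketch of the two input paths (module names and vocabulary sizes are assumptions, not the repo's actual code), contrasting a unified IPA embedding plus a `lang` embedding with a raw-text embedding:

```python
import torch
from torch import nn

d_model  = 1024
phn_emb  = nn.Embedding(256, d_model)    # unified IPA phoneme vocabulary (size assumed)
raw_emb  = nn.Embedding(8192, d_model)   # raw text token vocabulary (size assumed)
lang_emb = nn.Embedding(32, d_model)     # language id embedding (`lang`)

# normal path: phoneme tokens + a separate language embedding
phoneme_ids = torch.tensor([12, 45, 7])
phn_seq = phn_emb(phoneme_ids)
lang    = lang_emb(torch.tensor([0]))

# raw-text path: a model trained on raw text embeds text tokens directly instead
text_ids = torch.tensor([101, 7, 2048])
raw_seq  = raw_emb(text_ids)
```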
@@ -280,6 +282,14 @@ The primary benefit of this task is to provide a fast way to directly transcribe
This task will follow a reverse sequence of `<audio><language><RVQ level><output>`.
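As a minimal sketch of that ordering (shapes and names are illustrative only, not the actual implementation):

```python
import torch

d_model   = 1024
audio_emb = torch.randn(120, d_model)  # embedded audio codes of the utterance to transcribe
lang_emb  = torch.randn(1, d_model)    # `lang` embedding
rvq_emb   = torch.randn(1, d_model)    # RVQ level embedding
text_emb  = torch.randn(16, d_model)   # output text tokens (teacher-forced during training)

# `stt` conditions on the audio first and emits the text last, the reverse of TTS
stt_sequence = torch.cat([audio_emb, lang_emb, rvq_emb, text_emb], dim=0)
```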
#### Phonemize / Un-Phonemize
The `phn` task takes raw text and outputs the corresponding IPA phonemes.
The `un-phn` task does the opposite: it takes IPA phonemes and outputs the text that would phonemize into them.
Currently, `phn` works *okay*, while `un-phn` does not work at all.
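For reference, a minimal sketch of the offline phonemizer path that `phn` mirrors in-model; the import path is an assumption (the diff only shows `g2p.encode(...)` being called):

```python
from vall_e.emb.g2p import encode as phonemize  # hypothetical import path

ipa = phonemize("Hello world.", language="en")  # IPA phoneme string for the input text
```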
## Emergent Behavior
The model can be prompted in creative ways to yield some interesting behaviors:

View File

@@ -20,6 +20,7 @@ def main():
parser.add_argument("--split-text-by", type=str, default="\n")
parser.add_argument("--context-history", type=int, default=0)
parser.add_argument("--no-phonemize", action='store_true')
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--model", type=Path, default=None)
@@ -87,6 +88,7 @@ def main():
sampling_kwargs = dict(
split_text_by=args.split_text_by,
context_history=args.context_history,
phonemize=not args.no_phonemize,
max_steps=args.max_steps,
max_levels=args.max_levels,
max_duration=args.max_duration,

View File

@@ -98,7 +98,7 @@ class TTS():
def disable_lora( self ):
return self.enable_lora( enabled=False )
def encode_text( self, text, language="auto", precheck=True ):
def encode_text( self, text, language="auto", precheck=True, phonemize=True ):
# already a tensor, return it
if isinstance( text, Tensor ):
return text
@@ -109,10 +109,10 @@ class TTS():
if self.symmap["<unk>"] not in tokens:
return torch.tensor( tokens )
content = g2p.encode(text, language=language)
tokens = tokenize( content )
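# when not phonemizing, tokenize the raw text with the text tokenizer rather than the phoneme tokenizer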
if not phonemize:
return torch.tensor( text_tokenize( text ) )
return torch.tensor( tokens )
return torch.tensor( tokenize( g2p.encode(text, language=language) ) )
def encode_lang( self, language ):
symmap = get_lang_symmap()
@@ -361,6 +361,7 @@ class TTS():
use_lora = sampling_kwargs.pop("use_lora", None)
dtype = sampling_kwargs.pop("dtype", self.dtype)
amp = sampling_kwargs.pop("amp", self.amp)
phonemize = sampling_kwargs.pop("phonemize", True)
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
voice_convert = sampling_kwargs.pop("voice_convert", None)
@@ -431,10 +432,10 @@ class TTS():
model = model_ar if model_ar is not None else model_nar
if task == "phn":
text_list = None
raw_text_list = [ torch.tensor( text_tokenize( text ), device=self.device, dtype=torch.int16) ]
raw_text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ]
output_tokenizer = cfg.tokenizer
else:
text_list = [ torch.tensor( tokenize( text ), device=self.device, dtype=torch.int16) ]
text_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ]
raw_text_list = None
output_tokenizer = cfg.text_tokenizer
@@ -489,12 +490,13 @@ class TTS():
if auto_text_lang:
text_language = deduced_language
phns = self.encode_text( line, language=text_language )
phns = self.encode_text( line, language=text_language, phonemize=phonemize )
phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
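# `phns` now holds either IPA phoneme token ids or raw text token ids; it is routed to `text_list` or `raw_text_list` accordingly below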
with torch.autocast(self.device, dtype=dtype, enabled=amp):
input_kwargs = dict(
text_list=[phns],
text_list=[phns] if phonemize else None,
raw_text_list=[phns] if not phonemize else None,
proms_list=[prom],
lang_list=[lang],
disable_tqdm=not use_tqdm,

View File

@@ -217,6 +217,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--voice-convert", type=str, default=kwargs["voice-convert"])
parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
parser.add_argument("--no-phonemize", action="store_true")
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
@@ -272,6 +273,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if kwargs.pop("refine-on-stop", False):
args.refine_on_stop = True
if kwargs.pop("no-phonemize", False):
args.no_phonemize = True
if args.split_text_by == "lines":
args.split_text_by = "\n"
elif args.split_text_by == "none":
@@ -287,6 +291,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
sampling_kwargs = dict(
split_text_by=args.split_text_by,
context_history=args.context_history,
phonemize=not args.no_phonemize,
voice_convert=args.voice_convert,
max_steps=args.max_steps,
max_levels=args.max_levels,
@@ -467,6 +472,7 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="How to split the text into utterances.", value="sentences")
layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
layout["inference_tts"]["inputs"]["no-phonemize"] = gr.Checkbox(label="No Phonemize", info="Use raw text rather than phonemize the text as the input prompt.")
with gr.Tab("Sampler Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Adjusts the probabilities in the AR/NAR-len. (0 to greedy* sample)")