README tweaks, added --input-prompt-prefix as an experiment (its literally better to just not do this, but i'll retain it in case i have a revelation on how to improve it)

This commit is contained in:
mrq 2024-10-04 18:57:19 -05:00
parent a9fa0898a9
commit 4a8e3ccf06
5 changed files with 32 additions and 11 deletions

View File

@ -64,7 +64,10 @@ If you already have a dataset you want, for example, your own large corpus or fo
3. Run `python3 -m vall_e.emb.process`. This will phonemize the transcriptions and quantize the audio.
+ If you're using a Descript-Audio-Codec based model, ensure to set the sample rate and audio backend accordingly.
4. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset/list.json`.
4. Run `python3 -m vall_e.emb.similar`. This will calculate the top-k most similar utterances for each utterance for use with sampling.
+ Doing this will help the model follow the input prompt stronger, at the possible "cost" of the model not learning how to "infer" the target speaker AND prosidy.
5. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset/list.json`.
+ Refer to `./vall_e/config.py` for additional configuration details.
### Dataset Formats
@ -88,7 +91,7 @@ For multiple GPUs, or exotic distributed training:
You can enter `save` to save the state at any time, or `quit` to save and quit training.
The `lr` will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
### Finetuning
@ -205,8 +208,12 @@ Some additional flags you can pass are:
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
* `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it works fine.
* `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, the lower value, the better. Set to `0` to enable greedy sampling.
* `--input-prompt-length`: the maximum duration the input prompt can be (~6 seconds is fine, longer durations lead to slower generations for "better" accuracy, as long as the model was trained against such input prompt durations)
And some experimental sampling flags you can use too (your mileage will ***definitely*** vary, but most of these are bandaids for a bad AR):
* `--input-prompt-prefix`: (AR only) treats the input prompt as the initial response prefix, but...
* the transcription of the prompt needs to be in the input text prompt.
* doesn't perform all that well (I belive the model needs to be trained a bit on this, as `tts-c`).
* `--min-ar-temp`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token. Acceptable values are between `[0.0, (n)ar-temp)`.
+ This simply uplifts the [original implementation](https://github.com/kalomaze/koboldcpp/blob/dynamic-temp/llama.cpp#L5132) to perform it.
+ **!**NOTE**!**: This does not seem to resolve any issues with setting too high/low of a temperature. The right values are yet to be found.
@ -273,8 +280,9 @@ So far, this only allows you to load a different model without needing to restar
- ~~this might need a better training paradigm with providing similar enough input prompts to a given output response.~~
- this might have just needed a better dataset + a better input prompt "sampling" method
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
* [x] ~~explore alternative setups, like a NAR-only model~~
* [x] ~~explore alternative setups, like a NAR-only model or Descript-Audio-Codec~~
- the current experiment of an AR length-predictor + NAR for the rest seems to fall apart...
- Descript-Audio-Codec 44KHz has NAR issues, but this *might* be user error.
* [x] ~~explore better sampling techniques~~
- the AR doesn't *need* exotic sampling techniques, as they're bandaids for a bad AR.
- the NAR benefits from greedy sampling, and anything else just harms output quality.
@ -290,6 +298,9 @@ So far, this only allows you to load a different model without needing to restar
- this *technically* can work without any additional architecture changes, just clever tricks with sampling-then-decoding-to-audio.
- something similar to HiFiGAN (or the one for TorToiSe) trained on the last hidden states of the AR *might* also enable an alternate way for streaming.
* [ ] replace the phonemizer with something that doesn't depend on espeak
* [ ] train the model to handle text => phoneme (without a hit to the rest of the model)
* [ ] ...and phonemes => text
* [ ] allow raw text as input instead
- espeak is nice, but I can only really put my whole trust with phonemizing English.
- a small model trained to handle converting text to phonemes might work, but has it's own problems (another model to carry around, as accurate as the dataset it was trained against, requires training for each language... etc).
* [ ] explore exotic features like:

View File

@ -26,6 +26,7 @@ def main():
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
parser.add_argument("--input-prompt-length", type=float, default=3.0)
parser.add_argument("--input-prompt-prefix", action="store_true")
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
@ -57,6 +58,7 @@ def main():
task=args.task,
out_path=args.out_path,
input_prompt_length=args.input_prompt_length,
input_prompt_prefix=args.input_prompt_prefix,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,

View File

@ -174,6 +174,7 @@ class TTS():
max_nar_levels=7,
#
input_prompt_length=0.0,
input_prompt_prefix=False,
#
ar_temp=0.95,
nar_temp=0.5,
@ -275,6 +276,7 @@ class TTS():
if model_ar is not None:
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps,
input_prompt_prefix=input_prompt_prefix,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
@ -291,6 +293,7 @@ class TTS():
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
input_prompt_prefix=input_prompt_prefix,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,

View File

@ -45,6 +45,8 @@ class AR_NAR(Base):
max_steps: int = 1000,
max_levels: int = 0,
input_prompt_prefix: bool = False,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
@ -245,6 +247,7 @@ class AR_NAR(Base):
enable_lora( self, cfg.lora.active_level( 0 ) )
# STT
start_slice = [ 0 for _ in range(batch_size) ]
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
@ -258,10 +261,15 @@ class AR_NAR(Base):
scores = [ 1.0 ] * sampling_beam_width
# add <bos> to text for STT
for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT
if task_list[i] in text_task:
start_slice[i] = 1
sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
# treat input prompt as initial resp (by prefixing with the prompt instead)
elif input_prompt_prefix:
start_slice[i] = proms_list[i].shape[0]
sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i]
# get next in sequence
for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
@ -269,12 +277,6 @@ class AR_NAR(Base):
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
"""
print( "task_list:", task_list )
print( "text_list:", text_list )
print( "resps_list:", resps_list )
"""
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
@ -357,7 +359,7 @@ class AR_NAR(Base):
# remove stop token
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
# remove <bos>
sequence_list = [ sequence_list[i] if task not in text_task else sequence_list[i][1:] for i, task in enumerate( task_list ) ]
sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
return sequence_list

View File

@ -164,6 +164,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
@ -203,6 +204,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
max_ar_steps=args.max_ar_steps,
max_nar_levels=args.max_nar_levels,
input_prompt_length=args.input_prompt_length,
input_prompt_prefix=args.input_prompt_prefix,
ar_temp=args.ar_temp,
nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp,
@ -384,6 +386,7 @@ with ui:
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):