validated rep pen for STT (sometimes needed to wrangle the model)

This commit is contained in:
mrq 2024-09-08 08:30:30 -05:00
parent 6a967f91b9
commit 54203c059d
5 changed files with 11 additions and 5 deletions

View File

@ -50,7 +50,7 @@ def main():
args = parser.parse_args()
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
tts.inference(
output = tts.inference(
text=args.text,
references=args.references,
language=args.language,
@ -69,5 +69,8 @@ def main():
seed=args.seed,
)
if isinstance( output, str ):
print( output )
if __name__ == "__main__":
main()

View File

@ -209,6 +209,9 @@ class ModelExperimentalSettings:
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
# a model trained not summing audio embeddings *can* have this enabled without any apparent issues
# a model trained to sum *cannot* have this disabled without any apparent issues, or at least the ar+nar-retnet-8 can't.
# in theory a model that is trained to sum embeddings can peform better due to "seeing" previous levles (due to the R in RVQ standing for residuals...), but in practice it seems fine to not do so
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
kv_heads: int = 0 # MHA or GQA (for supported backends)
p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
@ -225,6 +228,7 @@ class ModelExperimentalSettings:
# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
# however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE
# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
# RetNet's chunked inferencing might be a better place for this
p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
# to-to: just incorporate this as a task instead

View File

@ -211,7 +211,6 @@ class TTS():
raise Exception("!")
text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
print( text_list )
return text_list[0]

View File

@ -1494,8 +1494,8 @@ class Base(nn.Module):
return [ logit.argmax(dim=1) for logit in logits ]
# perform repetition penalizing
if "len" not in self.capabilities and repetition_penalty != 1.0:
logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, prev_list ) ]
if "len" not in self.capabilities:
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal:

View File

@ -396,7 +396,7 @@ with ui:
layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
with gr.Row():
layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.25, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
with gr.Row():