validated rep pen for STT (sometimes needed to wrangle the model)
parent 6a967f91b9
commit 54203c059d
@@ -50,7 +50,7 @@ def main():
 	args = parser.parse_args()
 
 	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
-	tts.inference(
+	output = tts.inference(
 		text=args.text,
 		references=args.references,
 		language=args.language,
@@ -68,6 +68,9 @@ def main():
 		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
 		seed=args.seed,
 	)
+
+	if isinstance( output, str ):
+		print( output )
 
 if __name__ == "__main__":
 	main()
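Capturing the return value matters here because, for STT, inference hands back the transcription as a string rather than writing audio, so the CLI now prints it. A minimal sketch of the same pattern called from Python instead of the CLI; the module path and the task keyword are assumptions not shown in this diff:

    # a minimal sketch, not the repo's verbatim API: assumes TTS is importable
    # from vall_e.inference and that inference() accepts a task keyword for STT
    from vall_e.inference import TTS

    tts = TTS( config="./model.yaml", device="cuda" )  # hypothetical config path
    output = tts.inference( text="", references=["./utterance.wav"], task="stt" )

    if isinstance( output, str ):
        print( output )  # STT: the transcription itself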
@@ -209,6 +209,9 @@ class ModelExperimentalSettings:
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
 	split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
 	audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
+	# a model trained without summing audio embeddings *can* have this enabled without any apparent issues
+	# a model trained to sum *cannot* have this disabled without issues, or at least the ar+nar-retnet-8 can't
+	# in theory a model trained to sum embeddings can perform better by "seeing" previous levels (the R in RVQ standing for residuals...), but in practice it seems fine not to
 	audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
 	p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
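For readers outside the codebase, a rough illustration of what `audio_embedding_sums` toggles. This is a standalone sketch, not the model's actual embedding code:

    import torch
    import torch.nn as nn

    n_levels, n_tokens, d_model = 8, 1024, 1024
    embs = nn.ModuleList([ nn.Embedding( n_tokens, d_model ) for _ in range( n_levels ) ])

    def embed_codes( codes: torch.Tensor, level: int, sums: bool ) -> torch.Tensor:
        # codes: [timesteps, n_levels] of RVQ codebook indices
        if sums:
            # "seeing" every prior residual level: sum embeddings for levels 0..level
            return sum( embs[l]( codes[:, l] ) for l in range( level + 1 ) )
        # otherwise embed only the current level's codes
        return embs[level]( codes[:, level] )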
@@ -225,6 +228,7 @@ class ModelExperimentalSettings:
 	# VALL-E 2's approach of "combining token embeddings to group them" sounds terrible for a shared AR/NAR model
 	# however, introducing partial parallel decoding for the AR just might help unify the AR/NAR tasks better, MAYBE
 	# it just seems like a pain to train something worthwhile with it, since there's crackles every other token
+	# RetNet's chunked inferencing might be a better place for this
 
 	p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
 	# to-do: just incorporate this as a task instead
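As a hypothetical illustration of how an odds knob like `p_len_train` is typically consumed when sampling training tasks; the actual task-selection code is not part of this diff:

    import random

    def sample_task( p_len_train: float = 0.05 ) -> str:
        # roughly one training sample in twenty becomes a "len" (duration) task
        return "len" if random.random() < p_len_train else "tts"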
@@ -211,7 +211,6 @@ class TTS():
 			raise Exception("!")
 
 		text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
-		print( text_list )
 
 		return text_list[0]
 
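The `replace` chain in the kept context line is easy to misread, especially since the page may have normalized its whitespace characters. The apparent intent: the tokenizer joins decoded tokens with spaces, so genuine word boundaries are first swapped to a placeholder, the per-token joining spaces are stripped, then the placeholder becomes a single space again. A hedged, self-contained illustration with hypothetical phoneme output:

    # hypothetical decoder output: phoneme tokens joined by spaces, with "_"
    # standing in for a real word boundary after the first replace() in the chain
    decoded = "h ɛ l ˈoʊ _ w ˈɚ l d"
    cleaned = decoded.replace( " ", "" ).replace( "_", " " )
    print( cleaned )  # hɛlˈoʊ wˈɚld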
@@ -1494,8 +1494,8 @@ class Base(nn.Module):
 			return [ logit.argmax(dim=1) for logit in logits ]
 
 		# perform repetition penalizing
-		if "len" not in self.capabilities and repetition_penalty != 1.0:
-			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, prev_list ) ]
+		if "len" not in self.capabilities:
+			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
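This is the substance of the commit: the `repetition_penalty != 1.0` guard is dropped, and the comprehension now tolerates 1-D `prevs` tensors (the shape STT produces), which is what makes repetition penalty usable for wrangling transcriptions. A minimal sketch of the kind of penalty a `reptition_penalize`-style helper applies; the real sampler utility may differ in its details:

    import torch

    def penalize_repeats( logits: torch.Tensor, previous: list[int], factor: float = 1.25, decay: float = 0.0 ) -> torch.Tensor:
        # logits: [batch, n_tokens]; previous: already-emitted token ids, oldest first
        seen = set()
        for distance, token in enumerate( reversed( previous ), start=1 ):
            if token in seen:
                continue  # penalize each token once, at its most recent occurrence
            # divide the token's logit by the penalty factor, scaled by how far
            # back the token occurred when a nonzero decay is given
            logits[:, token] /= factor * ( distance ** decay )
            seen.add( token )
        return logits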
@@ -396,7 +396,7 @@ with ui:
 				layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
 				layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 			with gr.Row():
-				layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+				layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.25, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 				layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the repetition penalty based on how far back in time the token appeared in the sequence.")
 				layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 			with gr.Row():
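The only functional change in the UI hunk is the STT tab's default Repetition Penalty moving from 1.0 to 1.25, matching the commit message. A hypothetical sketch of how those slider values would end up in an inference call; the actual wiring lives outside this diff, and the keyword names are assumptions mirroring the sliders:

    # hypothetical: sampler kwargs mirroring the STT tab's new defaults
    sampler_kwargs = {
        "repetition_penalty": 1.25,        # the raised default above
        "repetition_penalty_decay": 0.0,
        "length_penalty": 0.0,
    }
    output = tts.inference( text="", references=["./utterance.wav"], task="stt", **sampler_kwargs )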