diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index d355c5e..6c41fc0 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -50,7 +50,7 @@ def main():
 	args = parser.parse_args()

 	tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
-	tts.inference(
+	output = tts.inference(
 		text=args.text,
 		references=args.references,
 		language=args.language,
@@ -68,6 +68,9 @@ def main():
 		dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
 		seed=args.seed,
 	)
+
+	if isinstance( output, str ):
+		print( output )

 if __name__ == "__main__":
 	main()
diff --git a/vall_e/config.py b/vall_e/config.py
index afc4f73..5708fe7 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -209,6 +209,9 @@ class ModelExperimentalSettings:
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
 	split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
 	audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
+	# a model trained without summing audio embeddings *can* have this enabled without any apparent issues
+	# a model trained to sum *cannot* have this disabled, or at least the ar+nar-retnet-8 can't
+	# in theory a model trained to sum embeddings can perform better due to "seeing" previous levels (the R in RVQ stands for residuals...), but in practice it seems fine not to
 	audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
 	p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
@@ -225,6 +228,7 @@ class ModelExperimentalSettings:
 	# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
 	# however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE
 	# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
+	# RetNet's chunked inferencing might be a better place for this
 	p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
 	# to-to: just incorporate this as a task instead
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 7663d29..66ef906 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -211,7 +211,6 @@ class TTS():
 				raise Exception("!")

 			text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
-			print( text_list )

 			return text_list[0]
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 533c678..620b349 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1494,8 +1494,8 @@ class Base(nn.Module):
 			return [ logit.argmax(dim=1) for logit in logits ]

 		# perform repetition penalizing
-		if "len" not in self.capabilities and repetition_penalty != 1.0:
-			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, prev_list ) ]
+		if "len" not in self.capabilities:
+			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]

 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
diff --git a/vall_e/webui.py b/vall_e/webui.py
index dc7ea97..ac3a391 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -396,7 +396,7 @@ with ui:
 					layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
 					layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 				with gr.Row():
-					layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
+					layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.25, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 					layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
 					layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 				with gr.Row():