diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 6de6dcc..f1fcc8a 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -14,6 +14,7 @@ def main():
 	parser.add_argument("references", type=path_list, default=None)
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--task", type=str, default="tts")
+	parser.add_argument("--modality", type=str, default="auto")
 	parser.add_argument("--out-path", type=Path, default=None)
 	parser.add_argument("--yaml", type=Path, default=None)
 
@@ -108,6 +109,7 @@ def main():
 		references=args.references,
 		language=args.language,
 		task=args.task,
+		modality=args.modality,
 		out_path=args.out_path,
 
 		input_prompt_length=args.input_prompt_length,
diff --git a/vall_e/config.py b/vall_e/config.py
index dff2690..b7c33c2 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -262,14 +262,15 @@ class ModelExperimentalSettings:
 
 	masking_train_p: float = 0.0 # odds of training with masking
 	masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
-	masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick
+	masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick, "rand" will pick between [0.2, 0.8]
 	ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
 
-	# classifier-free guidance shit
+	# classifier-free guidance training settings
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
 	cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
 	cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out input audio prompt during training
 
+	# failed experiment
 	layerskip: bool = False # layerskip compatible model (or training for)
 	#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
 	layerskip_r: int = 2 # number of layers to factor into early-exit loss calc
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 61e2da7..041af36 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -57,6 +57,7 @@ def main():
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--task", type=str, default="tts")
+	parser.add_argument("--modality", type=str, default="auto")
 	parser.add_argument("--out-path", type=Path, default=None)
 
 	parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
 
@@ -230,6 +231,8 @@ def main():
 	html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
 
 	sampling_kwargs = dict(
+		task=args.task,
+		modality=args.modality,
 		max_steps=args.max_steps,
 		max_levels=args.max_levels,
 		max_duration=args.max_duration,
diff --git a/vall_e/inference.py b/vall_e/inference.py
index af943c8..b38eafd 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -179,6 +179,12 @@ class TTS():
 				sums = False
 			) for l in range( input.shape[-1] - 1 ) ])
 
+	def modality( self, modality ):
+		# cringe to handle the best default mode for a given model
+		if modality == "auto" and cfg.model.name in ["ar+nar", "nar-len"]:
+			modality = cfg.model.name
+		return modality
+
 	@torch.inference_mode()
 	def inference(
 		self,
@@ -186,6 +192,7 @@ class TTS():
 		text,
 		references,
 		language="en",
 		task="tts",
+		modality="auto",
 		input_prompt_length = 0,
 		load_from_artifact = False,
@@ -215,6 +222,14 @@ class TTS():
 		seed = set_seed(seed)
 
+		modality = self.modality( modality )
+		# force AR+NAR
+		if modality == "ar+nar":
+			model_len = None
+		# force NAR-len
+		elif modality == "nar-len":
+			model_ar = None
+
 		if task == "stt":
 			resp = self.encode_audio( references )
 			lang = self.encode_lang( language )
 
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 3945335..b20567b 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -254,9 +254,11 @@ class AR_NAR(Base):
 		refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
 		entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
 
-		temperature = sampling_kwargs.pop("temperature", 1.0)
-		cfg_strength = sampling_kwargs.get("cfg_strength", 3.0) # this really helps keep audio coherent so far
-		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
+		# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
+		temperature = sampling_kwargs.pop("temperature", 0.0)
+		# this really helps keep audio coherent so far
+		cfg_strength = sampling_kwargs.get("cfg_strength", 2.0)
+		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
 		start_noise = sampling_kwargs.get("denoise_start", 0.0)
 		end_noise = sampling_kwargs.get("denoise_end", 1.0)
 		max_steps = math.floor(max_steps * (end_noise - start_noise))
@@ -283,7 +285,6 @@ class AR_NAR(Base):
 			annealing = 1.0 - timestep
 			# get noise level, per cosine scheduling
 			noise_p = math.cos( timestep * math.pi * 0.5 )
-			#noise_p = annealing
 			# pick the worst scoring tokens to mask off
 			masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
 			# mask off inputs
@@ -293,7 +294,6 @@ class AR_NAR(Base):
 			# timestep inputs
 			time_list = [ timestep for _ in range(batch_size) ]
 
-			# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
 			sampling_temperature = temperature * annealing
 			sampling_cfg = cfg_strength * timestep
 
@@ -364,7 +364,7 @@ class AR_NAR(Base):
 				1.0 -
 					# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
 					torch.where( masked, torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked.shape, device=device) )
-					# use unmodified logit scores for this, as it offers better stability
+					# use unmodified logit scores for this, as it offers better stability
 				for scores, masked in zip( unfiltered_sampled.scores, is_masked )
 			]
 
@@ -395,7 +395,6 @@ class AR_NAR(Base):
 
 		device = resps_list[0].device
 		batch_size = len(resps_list)
 
-		# convert NAR specific args
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "nar_" )
 
@@ -431,19 +430,6 @@ class AR_NAR(Base):
 				**sampling_kwargs,
 			)
 
-			"""
-			resps_list = self.forward_nar_masked(
-				text_list=text_list,
-				proms_list=proms_list,
-				resps_list=resps_list,
-				task_list=task_list,
-				lang_list=lang_list,
-				tone_list=tone_list,
-				len_list=len_list,
-				**(sampling_kwargs|{"denoise_start": 0.5}),
-			)
-			"""
-
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(resps_list):
 			if resp.dim() == 1:
diff --git a/vall_e/webui.py b/vall_e/webui.py
index bb74d4b..9772760 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -202,6 +202,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	# I'm very sure I can procedurally generate this list
 	parser.add_argument("--text", type=str, default=kwargs["text"])
 	parser.add_argument("--task", type=str, default="tts")
+	parser.add_argument("--modality", type=str, default=kwargs["modality"])
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
 	parser.add_argument("--language", type=str, default=kwargs["language"])
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
@@ -258,16 +259,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 
 	tts = init_tts()
 
-	gr.Info("Inferencing...")
-
-	# icky
-	modality = kwargs.get("modality")
-	if modality:
-		for name, engine in tts.engines.items():
-			if modality == "AR+NAR":
-				engine.hyper_config.capabilities = ["ar", "nar"]
-			elif modality == "NAR-len":
-				engine.hyper_config.capabilities = ["nar", "len"]
+	gr.Info(f"Inferencing... (Modality: {tts.modality(args.modality.lower())})")
 
 	sampling_kwargs = dict(
 		max_steps=args.max_steps,
@@ -293,12 +285,13 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		input_prompt_length=args.input_prompt_length,
 		cfg_strength=args.cfg_strength,
 	)
-	
+
 	with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
 		wav, sr = tts.inference(
 			text=args.text,
 			language=args.language,
 			task=args.task,
+			modality=args.modality.lower(),
 			references=args.references.split(";") if args.references is not None else [],
 			**sampling_kwargs,
 		)
@@ -438,8 +431,9 @@ with ui:
 					layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
 					layout["inference_tts"]["inputs"]["nar-temperature"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=3.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
+					layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale")
 					layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
+					layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
 			with gr.Tab("Sampler Settings"):
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
@@ -464,7 +458,6 @@ with ui:
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
 					layout["inference_tts"]["inputs"]["prefix-silence"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Silence Prefix Duration", info="Amount of silence to prefix to the output response before beginning inference.")
-					layout["inference_tts"]["inputs"]["modality"] = gr.Dropdown(value="Auto", choices=["Auto", "AR+NAR", "NAR-len"], label="Modality", info="Whether to inference with the AR+NAR or through the NAR-len.")
 				with gr.Row():
 					layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
 					layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")