From bcf3910a1778ad2ca5a2cbf5ec3a7cbf893c1c00 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 12 Jun 2024 19:49:47 -0500 Subject: [PATCH] the NAR only dream is dead (it just won't work) --- vall_e/__main__.py | 1 + vall_e/inference.py | 61 ++++++++++++++++++++++++-------------------- vall_e/models/nar.py | 18 +++++++++++++ 3 files changed, 53 insertions(+), 27 deletions(-) diff --git a/vall_e/__main__.py b/vall_e/__main__.py index 1835be0..5054411 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -1,6 +1,7 @@ import argparse from pathlib import Path from .inference import TTS +from .config import cfg def path_list(arg): return [Path(p) for p in arg.split(";")] diff --git a/vall_e/inference.py b/vall_e/inference.py index ed61939..a609dc2 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -141,11 +141,14 @@ class TTS(): sr = None model_ar = None + model_len = None model_nar = None for name, engine in self.engines.items(): if "ar" in engine.hyper_config.capabilities: model_ar = engine.module + if "len" in engine.hyper_config.capabilities: + model_len = engine.module if "nar" in engine.hyper_config.capabilities: model_nar = engine.module @@ -168,33 +171,37 @@ class TTS(): # AR temp: 1 # NAR temp: 0.05 # prom size: 3 - - """ - resps_list = engine(text_list=text_list, proms_list=proms_list, max_steps=max_ar_steps, sampling_temperature=ar_temp) - resps_list = engine(text_list=text_list, proms_list=proms_list, resps_list=resps_list, sampling_temperature=nar_temp) - """ - - resps_list = model_ar( - text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context, - sampling_temperature=ar_temp, - sampling_min_temperature=min_ar_temp, - sampling_top_p=top_p, sampling_top_k=top_k, - sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, - sampling_length_penalty=length_penalty, - sampling_beam_width=beam_width, - sampling_mirostat_tau=mirostat_tau, - sampling_mirostat_eta=mirostat_eta, - ) - resps_list = model_nar( - text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, - max_levels=max_nar_levels, - sampling_temperature=nar_temp, - sampling_min_temperature=min_nar_temp, - sampling_top_p=top_p, sampling_top_k=top_k, - sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, - ) - """ - """ + if model_ar is not None: + resps_list = model_ar( + text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context, + sampling_temperature=ar_temp, + sampling_min_temperature=min_ar_temp, + sampling_top_p=top_p, sampling_top_k=top_k, + sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, + sampling_length_penalty=length_penalty, + sampling_beam_width=beam_width, + sampling_mirostat_tau=mirostat_tau, + sampling_mirostat_eta=mirostat_eta, + ) + resps_list = model_nar( + text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, + max_levels=max_nar_levels, + sampling_temperature=nar_temp, + sampling_min_temperature=min_nar_temp, + sampling_top_p=top_p, sampling_top_k=top_k, + sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, + ) + elif model_len is not None: + len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10 ) # don't need more than that + resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, + max_levels=max_nar_levels, + sampling_temperature=nar_temp, + sampling_min_temperature=min_nar_temp, + sampling_top_p=top_p, sampling_top_k=top_k, + sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, + ) + else: + raise Exception("!") wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device) wavs.append(wav) diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 466537f..ae45723 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -229,7 +229,25 @@ class NAR(Base): quant_levels=quant_levels, ) + """ resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ] + """ + + resps_list = super().sample( + logits=logits, + resps_list=prev_list, + quant_levels=quant_levels, + + temperature=1.0 if n == 0 else sampling_temperature, + min_temperature=sampling_min_temperature, + top_p=sampling_top_p, + top_k=sampling_top_k, + repetition_penalty=sampling_repetition_penalty, + repetition_penalty_decay=sampling_repetition_penalty_decay, + #length_penalty=sampling_length_penalty, + #beam_width=sampling_beam_width, + #mirostat=mirostat, + ) if n == 0: prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]