diff --git a/vall_e/config.py b/vall_e/config.py
index d026e49..3e77416 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -164,6 +164,7 @@ class Model:
 	training: bool = True
 	interleave: bool = False
 	frozen_params: list[str] = field(default_factory=lambda: [])
+	p_ar_nar: float = 0.5
 
 	@property
 	def full_name(self):
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index dc73470..9a00646 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -88,7 +88,7 @@ class AR_NAR(Base):
 
 			# is training
 			if n_levels == self.n_resp_levels:
-				if random.random() < 0.95:
+				if random.random() < cfg.models.ar_nar.p_ar_nar:
 					quant_levels = None
 
 					targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
diff --git a/vall_e/webui.py b/vall_e/webui.py
index b87356a..9872485 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -81,6 +81,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
 
+	if not args.references:
+		raise ValueError("No reference audio provided.")
+
 	tts = init_tts()
 
 	with timer() as t:
 		wav, sr = tts.inference(