From 1a02cd5bce9c164f0eaf0cd95254ffee4b36709f Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 21 Oct 2024 19:52:02 -0500 Subject: [PATCH] modify demo template to say F5 instead of YourTTS, swap LoRA comparison around to make the lora'd the base file, and the no-lora the suffix'd file --- data/demo/index.template.html | 2 +- vall_e/demo.py | 13 +++++++------ vall_e/models/ar_nar.py | 4 ++-- vall_e/samplers.py | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/data/demo/index.template.html b/data/demo/index.template.html index 23eeddc..c9c3712 100644 --- a/data/demo/index.template.html +++ b/data/demo/index.template.html @@ -13,7 +13,7 @@ Prompt Our VALL-E Original VALL-E - YourTTS + F5-TTS Ground Truth diff --git a/vall_e/demo.py b/vall_e/demo.py index 2003ffe..175e16f 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -33,7 +33,7 @@ from .emb.qnt import decode_to_file from tqdm import tqdm, trange def encode(path): - if path is None or path.exists(): + if path is None or not path.exists(): return "" return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8') @@ -117,11 +117,11 @@ def main(): # to-do: just make this mappable if args.comparison == "lora": - comparison_kwargs["suffix"] = "lora" - comparison_kwargs["titles"] = ["No LoRA", "LoRA"] + comparison_kwargs["suffix"] = "no_lora" + comparison_kwargs["titles"] = ["LoRA", "No LoRA"] - comparison_kwargs["disabled"]["use_lora"] = False - comparison_kwargs["enabled"]["use_lora"] = True + comparison_kwargs["disabled"]["use_lora"] = True + comparison_kwargs["enabled"]["use_lora"] = False elif args.comparison == "entropix-sampling": comparison_kwargs["suffix"] = "entropix_sampling" comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"] @@ -175,7 +175,7 @@ def main(): comparison_kwargs["disabled"]["amp"] = current_amp comparison_kwargs["enabled"]["amp"] = other_amp - else: + elif args.comparison: raise Exception(f"Unrecognized comparison flag: {args.comparison}") # read html template @@ -221,6 +221,7 @@ def main(): num = args.dataset_samples if args.dataset_samples else length for i in trange( num, desc="Sampling dataset for samples" ): + index = i if not cfg.dataset.sample_shuffle else random.randint( i, length ) batch = dataloader.dataset[i] dir = args.demo_dir / args.dataset_dir_name / f'{i}' diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index d7b54bd..58d0c69 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -302,8 +302,8 @@ class AR_NAR(Base): # to-do: tune these values, maybe have it factor based on confidence scores or something if low_temperature: enabled = n < low_temperature_range - sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty - sampling_repetition_penalty_decay = 0.5 if enabled else original_sampling_repetition_penalty_decay + sampling_repetition_penalty = 1.5 if enabled else original_sampling_repetition_penalty + sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay sampling_temperature = original_sampling_temperature if enabled else 1.0 inputs = self.inputs( diff --git a/vall_e/samplers.py b/vall_e/samplers.py index 419325a..7c4794a 100644 --- a/vall_e/samplers.py +++ b/vall_e/samplers.py @@ -11,7 +11,7 @@ from dataclasses import asdict, dataclass, field # Simple filter to modify a token's probability if it shows up in the past # `one_time` will only apply the penalty once # `decay` is a factor that will exponentially apply to how far away it is -def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ): +def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False ): if factor == 1.0 or previous is None: return logits