diff --git a/data/demo/index.template.html b/data/demo/index.template.html
index 23eeddc..c9c3712 100644
--- a/data/demo/index.template.html
+++ b/data/demo/index.template.html
@@ -13,7 +13,7 @@
Prompt |
Our VALL-E |
Original VALL-E |
- YourTTS |
+ F5-TTS |
Ground Truth |
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 2003ffe..175e16f 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -33,7 +33,7 @@ from .emb.qnt import decode_to_file
from tqdm import tqdm, trange
def encode(path):
- if path is None or path.exists():
+ if path is None or not path.exists():
return ""
- return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
+ return "data:audio/wav;base64," + base64.b64encode(path.read_bytes()).decode('utf-8')
@@ -117,11 +117,11 @@ def main():
# to-do: just make this mappable
if args.comparison == "lora":
- comparison_kwargs["suffix"] = "lora"
- comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
+ comparison_kwargs["suffix"] = "no_lora"
+ comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
- comparison_kwargs["disabled"]["use_lora"] = False
- comparison_kwargs["enabled"]["use_lora"] = True
+ comparison_kwargs["disabled"]["use_lora"] = True
+ comparison_kwargs["enabled"]["use_lora"] = False
elif args.comparison == "entropix-sampling":
comparison_kwargs["suffix"] = "entropix_sampling"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
@@ -175,7 +175,7 @@ def main():
comparison_kwargs["disabled"]["amp"] = current_amp
comparison_kwargs["enabled"]["amp"] = other_amp
- else:
+ elif args.comparison:
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
# read html template
@@ -221,6 +221,8 @@ def main():
num = args.dataset_samples if args.dataset_samples else length
for i in trange( num, desc="Sampling dataset for samples" ):
- batch = dataloader.dataset[i]
+ # randrange excludes `length`, so a shuffled pick cannot land one past the last sample
+ index = i if not cfg.dataset.sample_shuffle else random.randrange( i, length )
+ batch = dataloader.dataset[index]
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index d7b54bd..58d0c69 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -302,8 +302,8 @@ class AR_NAR(Base):
# to-do: tune these values, maybe have it factor based on confidence scores or something
if low_temperature:
enabled = n < low_temperature_range
- sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty
- sampling_repetition_penalty_decay = 0.5 if enabled else original_sampling_repetition_penalty_decay
+ sampling_repetition_penalty = 1.5 if enabled else original_sampling_repetition_penalty
+ sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
sampling_temperature = original_sampling_temperature if enabled else 1.0
inputs = self.inputs(
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 419325a..7c4794a 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -11,7 +11,7 @@ from dataclasses import asdict, dataclass, field
# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
-def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
+def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False ):
if factor == 1.0 or previous is None:
return logits