diff --git a/data/demo/index.template.html b/data/demo/index.template.html
index e5efdff..592788c 100644
--- a/data/demo/index.template.html
+++ b/data/demo/index.template.html
@@ -16,7 +16,7 @@
         <th>Prompt</th>
         <th>Our VALL-E</th>
         <th>Original VALL-E</th>
-        <th>F5-TTS</th>
+
         <th>Ground Truth</th>
@@ -32,7 +32,7 @@
         <th>SIM-O↑</th>
         <th>Prompt</th>
         <th>Our VALL-E</th>
-        <th>F5-TTS</th>
+
         <th>Ground Truth</th>
diff --git a/vall_e/demo.py b/vall_e/demo.py
index 8e626e6..c5eeb95 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -341,7 +341,9 @@ def main():
         samples = []
         speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
-        sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
+        speakers.sort()
+        #sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
+        sources = [ "ms_valle" ] if k == "librispeech" else []
         # generate demo output
         for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
@@ -376,19 +378,19 @@ def main():
             # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
             if args.comparison:
-                should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
+                should_generate = (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing)
                 if should_generate:
                     comparison_inputs.append((text, prompt, language, out_path_comparison))
-                    metrics_inputs.append((text, language, out_path_comparison, reference, metrics_path))
+                    metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path))
             should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
             if should_generate:
                 inputs.append((text, prompt, language, out_path))
-                metrics_inputs.append((text, language, out_path, reference, metrics_path))
+                metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path))
         outputs.append((k, samples))
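
Note on the fix above: the comparison branch previously tested `out_path` instead of `out_path_comparison`, so `--skip-existing` was keyed to the wrong file. The predicate `(args.skip_existing and not p.exists()) or not (args.skip_existing)` also reduces to `not (args.skip_existing and p.exists())`; a minimal sketch of an equivalent helper (hypothetical, not part of the repo):

    from pathlib import Path

    def should_generate(out_path: Path, skip_existing: bool) -> bool:
        # Regenerate unless --skip-existing is set AND the target already exists.
        # For the comparison branch, pass out_path_comparison here.
        return not (skip_existing and out_path.exists())
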
@@ -399,12 +401,12 @@ def main():
     process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
     metrics_map = {}
-    for text, language, out_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
+    for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
         calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
         if calculate:
             wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
-            sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
+            sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
             metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
             json_write( metrics, metrics_path )
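
The metrics tuples now carry the prompt path, and SIM-O is scored against the prompt rather than the ground-truth reference, which measures how faithfully the model clones the speaker it was actually conditioned on. Assuming `sim_o` follows the usual SIM-O recipe (cosine similarity between speaker embeddings from a speaker-verification model, here selected by `args.speaker_similarity_model`), a minimal sketch of the comparison, with the embedding extraction left abstract:

    import torch
    import torch.nn.functional as F

    def sim_o_sketch(generated_emb: torch.Tensor, prompt_emb: torch.Tensor) -> float:
        # Higher is better (the demo table marks the column SIM-O with an up arrow);
        # 1.0 means the generated audio matches the prompt speaker's embedding exactly.
        return F.cosine_similarity(generated_emb, prompt_emb, dim=-1).item()
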
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 5e50e07..eebff5e 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -252,6 +252,7 @@ class TTS():
         if not text_languages:
             text_languages = languages
+        inputs = []
         # tensorfy inputs
         for i in range( samples ):
             # detect language
@@ -266,17 +267,24 @@ class TTS():
             references[i] = to_device(references[i], device=self.device, dtype=torch.int16)
             languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8)
+            seq_len = texts[i].shape[0] + 1 + (references[i].shape[0] if references[i] is not None else 0) + 1
+
+            inputs.append((texts[i], references[i], languages[i], out_paths[i], seq_len))
+
+        # attempt to reduce padding
+        inputs.sort(key=lambda x: x[-1])
+
         # create batches
         batches = []
         buffer = ([], [], [], [])
-        for batch in zip( texts, references, languages, out_paths ):
+        for batch in inputs:
             # flush
             if len(buffer[0]) >= batch_size:
                 batches.append(buffer)
                 buffer = ([], [], [], [])
             # insert into buffer
-            for i, x in enumerate( batch ):
+            for i, x in enumerate( batch[:-1] ):
                 buffer[i].append(x)
         # flush
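
The batching change estimates a per-sample sequence length (text length plus optional reference length plus two separator tokens, going by the `+ 1` terms) and sorts inputs by it before chunking, so each batch pads to its own longest member instead of the global maximum. A self-contained sketch of the idea with toy tensors (names are illustrative, not the repo's):

    import torch
    import torch.nn.functional as F

    def make_batches(items, batch_size):
        # items: list of (tensor, seq_len); sorting groups similar lengths together
        items = sorted(items, key=lambda x: x[-1])
        return [ items[i:i + batch_size] for i in range(0, len(items), batch_size) ]

    items = [ (torch.zeros(n), n) for n in (5, 120, 7, 118, 6, 121) ]
    for batch in make_batches(items, batch_size=3):
        longest = max(n for _, n in batch)
        # the first batch pads to 7 and the second to 121; unsorted, most
        # samples would have padded up toward the global max of 121
        padded = torch.stack([ F.pad(t, (0, longest - t.shape[0])) for t, _ in batch ])

The `batch[:-1]` slice in the loop drops the sort key so the buffers still receive only (text, reference, language, out_path).
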
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c83e383..0b37558 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1546,7 +1546,7 @@ class Base(nn.Module):
         casual_levels = [ "AR:0:0", "stt", "len" ]
         # right now limit to new versions because I need to retrain the model for noncausal masks...
-        is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else None
+        is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
         output = self._forward(
             inputs=x,
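
Previously `is_causal` fell back to `None` when `self.noncausal_masks` was unset; now it is always a per-level list, with every level forced causal for older models trained without noncausal masks. As a rough sketch of what such a per-sequence flag usually selects between in mask-based attention (simplified, not the repo's `_forward`):

    import torch

    def attention_mask(seq_len: int, is_causal: bool) -> torch.Tensor:
        if is_causal:
            # lower-triangular: position i attends only to positions <= i
            # (the causal levels here: "AR:0:0", "stt", "len")
            return torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
        # full matrix: bidirectional attention for the non-causal levels
        return torch.ones(seq_len, seq_len, dtype=torch.bool)
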
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 77afe8d..567dead 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -13,7 +13,7 @@ from .utils import clamp
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is
-def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=None ):
+def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=None ):
     if factor == 1.0 or previous is None:
         return logits
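
Flipping the `one_time` default means a token that recurs in `previous` is now penalized once per occurrence, compounding the penalty for frequently repeated tokens. A hedged sketch of the behavior the comments above describe (an interpretation, not the repo's exact implementation; the penalty form and the decay weighting are assumptions):

    import math
    import torch

    def repetition_penalize_sketch(logits, previous, factor=1.25, decay=0.0, one_time=False):
        # logits: 1-D tensor of vocab scores; previous: list of past token ids
        seen = set()
        for distance, token in enumerate(reversed(previous)):
            if one_time and token in seen:
                continue
            seen.add(token)
            # with 0 < decay < 1, occurrences further in the past are penalized less
            weight = factor if decay == 0.0 else 1.0 + (factor - 1.0) * (decay ** distance)
            # subtracting log(weight) from a logit divides the token's
            # unnormalized probability by weight
            logits[token] = logits[token] - math.log(weight)
        return logits
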