diff --git a/data/demo/index.template.html b/data/demo/index.template.html index e5efdff..592788c 100644 --- a/data/demo/index.template.html +++ b/data/demo/index.template.html @@ -16,7 +16,7 @@ Prompt Our VALL-E Original VALL-E - F5-TTS + Ground Truth @@ -32,7 +32,7 @@ SIM-O↑ Prompt Our VALL-E - F5-TTS + Ground Truth diff --git a/vall_e/demo.py b/vall_e/demo.py index 8e626e6..c5eeb95 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -341,7 +341,9 @@ def main(): samples = [] speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ] - sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"] + speakers.sort() + #sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"] + sources = [ "ms_valle" ] if k == "librispeech" else [] # generate demo output for dir in tqdm(speakers, desc=f"Generating demo for {k}"): @@ -376,19 +378,19 @@ def main(): # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs) if args.comparison: - should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing) + should_generate = (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing) if should_generate: comparison_inputs.append((text, prompt, language, out_path_comparison)) - metrics_inputs.append((text, language, out_path_comparison, reference, metrics_path)) + metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path)) should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing) if should_generate: inputs.append((text, prompt, language, out_path)) - metrics_inputs.append((text, language, out_path, reference, metrics_path)) + metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path)) outputs.append((k, samples)) @@ -399,12 +401,12 @@ def main(): process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) ) metrics_map = {} - for text, language, out_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"): + for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"): calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime) if calculate: wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model ) - sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) + sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} json_write( metrics, metrics_path ) diff --git a/vall_e/inference.py b/vall_e/inference.py index 5e50e07..eebff5e 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -252,6 +252,7 @@ class TTS(): if not text_languages: text_languages = languages + inputs = [] # tensorfy inputs for i in range( samples ): # detect language @@ -266,17 +267,24 @@ class TTS(): references[i] = to_device(references[i], device=self.device, dtype=torch.int16) languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8) + seq_len = texts[i].shape[0] + 1 + (references[i].shape[0] if references[i] is not None else 0) + 1 + + inputs.append((texts[i], references[i], languages[i], out_paths[i], seq_len)) + + # attempt to reduce padding + inputs.sort(key=lambda x: x[-1]) + # create batches batches = [] buffer = ([], [], [], []) - for batch in zip( texts, references, languages, out_paths ): + for batch in inputs: # flush if len(buffer[0]) >= batch_size: batches.append(buffer) buffer = ([], [], [], []) # insert into buffer - for i, x in enumerate( batch ): + for i, x in enumerate( batch[:-1] ): buffer[i].append(x) # flush diff --git a/vall_e/models/base.py b/vall_e/models/base.py index c83e383..0b37558 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1546,7 +1546,7 @@ class Base(nn.Module): casual_levels = [ "AR:0:0", "stt", "len" ] # right now limit to new versions because I need to retrain the model for noncausal masks... - is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else None + is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ] output = self._forward( inputs=x, diff --git a/vall_e/samplers.py b/vall_e/samplers.py index 77afe8d..567dead 100644 --- a/vall_e/samplers.py +++ b/vall_e/samplers.py @@ -13,7 +13,7 @@ from .utils import clamp # Simple filter to modify a token's probability if it shows up in the past # `one_time` will only apply the penalty once # `decay` is a factor that will exponentially apply to how far away it is -def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=None ): +def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=None ): if factor == 1.0 or previous is None: return logits