sort batches to try and reduce number of padded tokens in batched inference (also commented out F5 samples getting added to the demo page because I would have to regenerate them)

This commit is contained in:
mrq 2024-12-11 22:45:38 -06:00
parent 20b87bfbd0
commit cddf8ca814
5 changed files with 22 additions and 12 deletions

View File

@ -16,7 +16,7 @@
<th>Prompt</th>
<th>Our VALL-E</th>
<th>Original VALL-E</th>
<th>F5-TTS</th>
<!--th>F5-TTS</th-->
<th>Ground Truth</th>
</tr>
</thead>
@ -32,7 +32,7 @@
<th>SIM-O↑</th>
<th>Prompt</th>
<th>Our VALL-E</th>
<th>F5-TTS</th>
<!--th>F5-TTS</th-->
<th>Ground Truth</th>
</tr>
</thead>

View File

@ -341,7 +341,9 @@ def main():
samples = []
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
speakers.sort()
#sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
sources = [ "ms_valle" ] if k == "librispeech" else []
# generate demo output
for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
@ -376,19 +378,19 @@ def main():
# segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
if args.comparison:
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
should_generate = (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing)
if should_generate:
comparison_inputs.append((text, prompt, language, out_path_comparison))
metrics_inputs.append((text, language, out_path_comparison, reference, metrics_path))
metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path))
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
if should_generate:
inputs.append((text, prompt, language, out_path))
metrics_inputs.append((text, language, out_path, reference, metrics_path))
metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path))
outputs.append((k, samples))
@ -399,12 +401,12 @@ def main():
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
metrics_map = {}
for text, language, out_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
if calculate:
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
json_write( metrics, metrics_path )

View File

@ -252,6 +252,7 @@ class TTS():
if not text_languages:
text_languages = languages
inputs = []
# tensorfy inputs
for i in range( samples ):
# detect language
@ -266,17 +267,24 @@ class TTS():
references[i] = to_device(references[i], device=self.device, dtype=torch.int16)
languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8)
seq_len = texts[i].shape[0] + 1 + (references[i].shape[0] if references[i] is not None else 0) + 1
inputs.append((texts[i], references[i], languages[i], out_paths[i], seq_len))
# attempt to reduce padding
inputs.sort(key=lambda x: x[-1])
# create batches
batches = []
buffer = ([], [], [], [])
for batch in zip( texts, references, languages, out_paths ):
for batch in inputs:
# flush
if len(buffer[0]) >= batch_size:
batches.append(buffer)
buffer = ([], [], [], [])
# insert into buffer
for i, x in enumerate( batch ):
for i, x in enumerate( batch[:-1] ):
buffer[i].append(x)
# flush

View File

@ -1546,7 +1546,7 @@ class Base(nn.Module):
casual_levels = [ "AR:0:0", "stt", "len" ]
# right now limit to new versions because I need to retrain the model for noncausal masks...
is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else None
is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
output = self._forward(
inputs=x,

View File

@ -13,7 +13,7 @@ from .utils import clamp
# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=None ):
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=None ):
if factor == 1.0 or previous is None:
return logits