sort batches to try and reduce number of padded tokens in batched inference (also commented out F5 samples getting added to the demo page because I would have to regenerate them)
This commit is contained in:
parent
20b87bfbd0
commit
cddf8ca814
|
@ -16,7 +16,7 @@
|
|||
<th>Prompt</th>
|
||||
<th>Our VALL-E</th>
|
||||
<th>Original VALL-E</th>
|
||||
<th>F5-TTS</th>
|
||||
<!--th>F5-TTS</th-->
|
||||
<th>Ground Truth</th>
|
||||
</tr>
|
||||
</thead>
|
||||
|
@ -32,7 +32,7 @@
|
|||
<th>SIM-O↑</th>
|
||||
<th>Prompt</th>
|
||||
<th>Our VALL-E</th>
|
||||
<th>F5-TTS</th>
|
||||
<!--th>F5-TTS</th-->
|
||||
<th>Ground Truth</th>
|
||||
</tr>
|
||||
</thead>
|
||||
|
|
|
@ -341,7 +341,9 @@ def main():
|
|||
|
||||
samples = []
|
||||
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
|
||||
sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
|
||||
speakers.sort()
|
||||
#sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
|
||||
sources = [ "ms_valle" ] if k == "librispeech" else []
|
||||
|
||||
# generate demo output
|
||||
for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
|
||||
|
@ -376,19 +378,19 @@ def main():
|
|||
|
||||
# segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
|
||||
if args.comparison:
|
||||
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
|
||||
should_generate = (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing)
|
||||
|
||||
if should_generate:
|
||||
comparison_inputs.append((text, prompt, language, out_path_comparison))
|
||||
|
||||
metrics_inputs.append((text, language, out_path_comparison, reference, metrics_path))
|
||||
metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path))
|
||||
|
||||
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
|
||||
|
||||
if should_generate:
|
||||
inputs.append((text, prompt, language, out_path))
|
||||
|
||||
metrics_inputs.append((text, language, out_path, reference, metrics_path))
|
||||
metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path))
|
||||
|
||||
outputs.append((k, samples))
|
||||
|
||||
|
@ -399,12 +401,12 @@ def main():
|
|||
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
||||
|
||||
metrics_map = {}
|
||||
for text, language, out_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||
for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||
calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
|
||||
|
||||
if calculate:
|
||||
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
||||
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||
sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||
|
||||
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
|
||||
json_write( metrics, metrics_path )
|
||||
|
|
|
@ -252,6 +252,7 @@ class TTS():
|
|||
if not text_languages:
|
||||
text_languages = languages
|
||||
|
||||
inputs = []
|
||||
# tensorfy inputs
|
||||
for i in range( samples ):
|
||||
# detect language
|
||||
|
@ -266,17 +267,24 @@ class TTS():
|
|||
references[i] = to_device(references[i], device=self.device, dtype=torch.int16)
|
||||
languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8)
|
||||
|
||||
seq_len = texts[i].shape[0] + 1 + (references[i].shape[0] if references[i] is not None else 0) + 1
|
||||
|
||||
inputs.append((texts[i], references[i], languages[i], out_paths[i], seq_len))
|
||||
|
||||
# attempt to reduce padding
|
||||
inputs.sort(key=lambda x: x[-1])
|
||||
|
||||
# create batches
|
||||
batches = []
|
||||
buffer = ([], [], [], [])
|
||||
for batch in zip( texts, references, languages, out_paths ):
|
||||
for batch in inputs:
|
||||
# flush
|
||||
if len(buffer[0]) >= batch_size:
|
||||
batches.append(buffer)
|
||||
buffer = ([], [], [], [])
|
||||
|
||||
# insert into buffer
|
||||
for i, x in enumerate( batch ):
|
||||
for i, x in enumerate( batch[:-1] ):
|
||||
buffer[i].append(x)
|
||||
|
||||
# flush
|
||||
|
|
|
@ -1546,7 +1546,7 @@ class Base(nn.Module):
|
|||
casual_levels = [ "AR:0:0", "stt", "len" ]
|
||||
|
||||
# right now limit to new versions because I need to retrain the model for noncausal masks...
|
||||
is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else None
|
||||
is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
|
||||
|
||||
output = self._forward(
|
||||
inputs=x,
|
||||
|
|
|
@ -13,7 +13,7 @@ from .utils import clamp
|
|||
# Simple filter to modify a token's probability if it shows up in the past
|
||||
# `one_time` will only apply the penalty once
|
||||
# `decay` is a factor that will exponentially apply to how far away it is
|
||||
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=None ):
|
||||
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=None ):
|
||||
if factor == 1.0 or previous is None:
|
||||
return logits
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user