sort batches to try to reduce the number of padded tokens in batched inference (also commented out F5 samples being added to the demo page, since I would have to regenerate them)
This commit is contained in: parent 20b87bfbd0, commit cddf8ca814
@@ -16,7 +16,7 @@
 <th>Prompt</th>
 <th>Our VALL-E</th>
 <th>Original VALL-E</th>
-<th>F5-TTS</th>
+<!--th>F5-TTS</th-->
 <th>Ground Truth</th>
 </tr>
 </thead>
@@ -32,7 +32,7 @@
 <th>SIM-O↑</th>
 <th>Prompt</th>
 <th>Our VALL-E</th>
-<th>F5-TTS</th>
+<!--th>F5-TTS</th-->
 <th>Ground Truth</th>
 </tr>
 </thead>
@@ -341,7 +341,9 @@ def main():
 
 samples = []
 speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
-sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
+speakers.sort()
+#sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
+sources = [ "ms_valle" ] if k == "librispeech" else []
 
 # generate demo output
 for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
@@ -376,19 +378,19 @@ def main():
 
 # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
 if args.comparison:
-should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
+should_generate = (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing)
 
 if should_generate:
 comparison_inputs.append((text, prompt, language, out_path_comparison))
 
-metrics_inputs.append((text, language, out_path_comparison, reference, metrics_path))
+metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path))
 
 should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
 
 if should_generate:
 inputs.append((text, prompt, language, out_path))
 
-metrics_inputs.append((text, language, out_path, reference, metrics_path))
+metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path))
 
 outputs.append((k, samples))
 
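An aside on the `should_generate` guard that this hunk repoints at `out_path_comparison`: the skip-existing expression can be written more compactly. A readability sketch only, not part of the diff:

# original form
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
# equivalent shorter form: generate unless we are skipping existing files AND this one already exists
should_generate = not (args.skip_existing and out_path.exists())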
@@ -399,12 +401,12 @@ def main():
 process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
 
 metrics_map = {}
-for text, language, out_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
+for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
 calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
 
 if calculate:
 wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
-sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
+sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
 
 metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
 json_write( metrics, metrics_path )
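With the change above, SIM-O is scored against the prompt audio (`prompt_path`) instead of the ground-truth reference, which is why the prompt now travels through `metrics_inputs`. SIM-O is a speaker-similarity metric; a minimal sketch of the general idea, with the embedding model passed in as a parameter since this repo's `sim_o()` internals are not shown here:

import torch
import torch.nn.functional as F
from typing import Callable

def speaker_similarity(generated: torch.Tensor, prompt: torch.Tensor,
                       embed: Callable[[torch.Tensor], torch.Tensor]) -> float:
    # `embed` stands in for any speaker-embedding model (an ECAPA/WavLM-style verifier);
    # SIM-O is simply the cosine similarity between the two speaker embeddings.
    return F.cosine_similarity(embed(generated), embed(prompt), dim=-1).item()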
@@ -252,6 +252,7 @@ class TTS():
 if not text_languages:
 text_languages = languages
 
+inputs = []
 # tensorfy inputs
 for i in range( samples ):
 # detect language
@@ -266,17 +267,24 @@ class TTS():
 references[i] = to_device(references[i], device=self.device, dtype=torch.int16)
 languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8)
 
+seq_len = texts[i].shape[0] + 1 + (references[i].shape[0] if references[i] is not None else 0) + 1
+
+inputs.append((texts[i], references[i], languages[i], out_paths[i], seq_len))
+
+# attempt to reduce padding
+inputs.sort(key=lambda x: x[-1])
+
 # create batches
 batches = []
 buffer = ([], [], [], [])
-for batch in zip( texts, references, languages, out_paths ):
+for batch in inputs:
 # flush
 if len(buffer[0]) >= batch_size:
 batches.append(buffer)
 buffer = ([], [], [], [])
 
 # insert into buffer
-for i, x in enumerate( batch ):
+for i, x in enumerate( batch[:-1] ):
 buffer[i].append(x)
 
 # flush
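This hunk is the commit's headline change: each input now carries an estimated sequence length (text tokens, plus reference audio tokens if present, plus separators), the inputs are sorted by that length, and batches are filled in sorted order. Because every sequence in a batch is padded up to the batch's longest member, grouping similar lengths together wastes far fewer padded tokens. A standalone sketch of the idea (illustrative names, not the `TTS` class's actual code):

def make_batches(items, lengths, batch_size):
    # sort indices shortest-to-longest, then chunk into fixed-size batches
    order = sorted(range(len(items)), key=lambda i: lengths[i])
    return [ [ items[i] for i in order[b:b + batch_size] ]
             for b in range(0, len(order), batch_size) ]

def padded_tokens(batch_lengths):
    # padding cost of one batch: every member is padded to the longest member
    longest = max(batch_lengths)
    return sum(longest - l for l in batch_lengths)

# e.g. lengths [10, 500, 12, 480, 9, 505] batched in pairs:
#   unsorted: (10,500) (12,480) (9,505)   -> 1454 padded tokens
#   sorted:   (9,10)   (12,480) (500,505) -> 474 padded tokens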
@@ -1546,7 +1546,7 @@ class Base(nn.Module):
 casual_levels = [ "AR:0:0", "stt", "len" ]
 
 # right now limit to new versions because I need to retrain the model for noncausal masks...
-is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else None
+is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
 
 output = self._forward(
 inputs=x,
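The `is_causal` change stops passing `None` when the model was not trained with non-causal masks and instead passes an explicit all-causal list, one flag per classifier level. As a rough illustration of what such a per-item flag selects between (a generic sketch, not this model's `_forward`):

import torch

def attention_mask(seq_len: int, causal: bool) -> torch.Tensor:
    # causal: lower-triangular mask, each position attends only to itself and the past
    # non-causal: full mask, every position attends to every position
    full = torch.ones(seq_len, seq_len, dtype=torch.bool)
    return torch.tril(full) if causal else full

# masks = [ attention_mask(seq_len, flag) for flag in is_causal ]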
@@ -13,7 +13,7 @@ from .utils import clamp
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is
-def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=None ):
+def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=None ):
 if factor == 1.0 or previous is None:
 return logits
 
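Flipping the `one_time` default from True to False means a repeated token is now penalized for every occurrence by default, not just once. The comments above describe the intent; a generic sketch of a distance-decayed repetition penalty, written as an illustration of the concept rather than this repo's exact formula:

def repetition_penalty_sketch(logits, previous, factor=1.2, decay=0.0, one_time=False):
    # `logits` is a 1-D tensor (or list) of per-token logits;
    # `previous` is the list of already-emitted token ids, oldest first
    seen = set()
    for distance, token in enumerate(reversed(previous), start=1):
        if one_time and token in seen:
            continue
        seen.add(token)
        # weaken the penalty exponentially the further back the repeat is (when decay > 0)
        strength = max(factor * (1.0 - decay) ** (distance - 1), 1.0)
        # shrink positive logits, push negative logits further down
        if logits[token] > 0:
            logits[token] /= strength
        else:
            logits[token] *= strength
    return logits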