tweaks to prompt duration so I can decouple how it's used for training from how it's used for the demo page, plus demo page tweaks to make my life easier

mrq 2024-12-17 19:33:04 -06:00
parent 7129582303
commit ed152f78df
4 changed files with 57 additions and 66 deletions

View File

@@ -5,39 +5,7 @@
 	<body>
 		<h1>VALL-E Demo</h1>
 		<p>${PREAMBLE}</p>
-		<table>
-			<thead>
-				<caption>LibriSpeech</caption>
-				<tr>
-					<th>Text</th>
-					<th>WER↓</th>
-					<th>CER↓</th>
-					<th>SIM-O↑</th>
-					<th>Prompt</th>
-					<th>Our VALL-E</th>
-					<th>Original VALL-E</th>
-					<!--th>F5-TTS</th-->
-					<th>Ground Truth</th>
-				</tr>
-			</thead>
-			<tbody>${LIBRISPEECH_SAMPLES}</tbody>
-		</table>
-		<table>
-			<thead>
-				<caption>Sampled Dataset</caption>
-				<tr>
-					<th>Text</th>
-					<th>WER↓</th>
-					<th>CER↓</th>
-					<th>SIM-O↑</th>
-					<th>Prompt</th>
-					<th>Our VALL-E</th>
-					<!--th>F5-TTS</th-->
-					<th>Ground Truth</th>
-				</tr>
-			</thead>
-			<tbody>${DATASET_SAMPLES}</tbody>
-		</table>
+		${TABLES}
 		<p>Settings used: <pre>${SETTINGS}</pre></p>
 	</body>
 </html>
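The two hardcoded tables are gone: demo.py now renders one table per dataset and substitutes the joined result into the single ${TABLES} placeholder. A minimal, self-contained sketch of that substitution; the markup, sample data, and the build_table helper here are illustrative, not the real template:

# minimal sketch of the ${TABLES} substitution; build_table is hypothetical
def build_table(dataset_name: str, rows: list[str]) -> str:
    table = "<h3>${DATASET_NAME}</h3>\n<table><tbody>${SAMPLES}</tbody></table>"
    table = table.replace("${DATASET_NAME}", dataset_name)
    return table.replace("${SAMPLES}", "\n".join(rows))

html = "<body>${TABLES}</body>"                              # illustrative template
outputs = [("librispeech", ["<tr><td>sample</td></tr>"])]    # (dataset, rows) pairs
tables = [build_table(name, rows) for name, rows in outputs]
html = html.replace("${TABLES}", "\n".join(tables))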

View File

@@ -1058,7 +1058,7 @@ class Dataset(_Dataset):
 	def sample_prompts(self, spkr_name, reference, should_trim=True):
 		# return no prompt if explicitly requested for who knows why
 		# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
-		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[spkr_name]) <= 1:
+		if len(self.paths_by_spkr_name[spkr_name]) <= 1:
 			return None
 
 		prom_list = []
@@ -1075,9 +1075,15 @@
 		)
 		"""
 
+		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
+			should_trim = False
+
 		prom_length = 0
-		duration_lo, duration_hi = cfg.dataset.prompt_duration_range
-		trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0
+		if should_trim:
+			duration_lo, duration_hi = cfg.dataset.prompt_duration_range
+			trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)
+		else:
+			trim_length = 0
 
 		for _ in range(cfg.dataset.prompt_max_samples):
 			if reference is not None:
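The net effect: an empty or zeroed prompt_duration_range no longer suppresses the prompt entirely, it only disables trimming, which is what lets the demo feed untrimmed prompts while training keeps its sampled range. A standalone sketch of the gating (the frames-per-second value is illustrative):

# standalone sketch of the new should_trim gating; values are illustrative
import random

frames_per_second = 75             # e.g. a 75 fps codec; illustrative
prompt_duration_range = []         # [] or a non-positive upper bound means "don't trim"
should_trim = True

if not prompt_duration_range or prompt_duration_range[1] <= 0:
    should_trim = False            # the prompt is still returned, just untrimmed

if should_trim:
    duration_lo, duration_hi = prompt_duration_range
    trim_length = int(random.uniform(duration_lo, duration_hi) * frames_per_second)
else:
    trim_length = 0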

View File

@@ -278,16 +278,24 @@ def main():
 	html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
 
 	# pull from provided samples
-	samples_dirs = {
-		"librispeech": args.demo_dir / "librispeech",
-	}
+	samples_dirs = {}
 
-	if (args.demo_dir / args.dataset_dir_name).exists():
+	# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
+	if "validation" in args.dataset_dir_name:
+		samples_dirs["librispeech"] = args.demo_dir / "librispeech"
+	# automatically pull from anything under the dataset dir
+	if args.dataset_dir_name.endswith("/*"):
+		args.dataset_dir_name = args.dataset_dir_name[:-2]
+		datasets = [ dir for dir in (args.demo_dir / args.dataset_dir_name).iterdir() if dir.is_dir() ]
+		for path in datasets:
+			samples_dirs[path.name] = path
+	# user provided dataset
+	elif (args.demo_dir / args.dataset_dir_name).exists():
 		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
 
 	# pull from dataset samples
 	if args.sample_from_dataset:
 		cfg.dataset.cache = False
 		cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
 		cfg.dataset.sample_order = "random"
 		cfg.dataset.tasks_list = [ 'tts' ]
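So a dataset dir name ending in /* now auto-discovers every subdirectory as its own named dataset, while a plain name still maps to a single "dataset" entry. A self-contained sketch of that resolution (the directory layout is illustrative):

# sketch of the dataset-dir resolution; paths are illustrative
from pathlib import Path

demo_dir = Path("demo")
dataset_dir_name = "custom/*"      # trailing /* means "every subdir is a dataset"

samples_dirs: dict[str, Path] = {}
if dataset_dir_name.endswith("/*"):
    root = demo_dir / dataset_dir_name[:-2]
    if root.exists():
        for path in (d for d in root.iterdir() if d.is_dir()):
            samples_dirs[path.name] = path   # one named dataset per subdirectory
elif (demo_dir / dataset_dir_name).exists():
    samples_dirs["dataset"] = demo_dir / dataset_dir_name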
@@ -335,18 +343,17 @@ def main():
 	outputs = []
 	metrics_inputs = []
 	comparison_inputs = []
-	for k, sample_dir in samples_dirs.items():
+	for dataset_name, sample_dir in samples_dirs.items():
 		if not sample_dir.exists():
 			continue
 
 		samples = []
 		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
 		speakers.sort()
-		#sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
-		sources = [ "ms_valle" ] if k == "librispeech" else []
+		sources = [ "ms_valle", "f5" ] if dataset_name == "librispeech" else []
 
 		# generate demo output
-		for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
+		for dir in tqdm(speakers, desc=f"Preparing demo for {dataset_name}"):
 			text = open(dir / "prompt.txt", encoding="utf-8").read()
 			language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
 			prompt = dir / "prompt.wav"
@@ -368,7 +375,7 @@ def main():
 			cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
 			"""
 
-			if not args.random_prompts or k == "librispeech":
+			if not args.random_prompts or dataset_name == "librispeech":
 				audio_samples += [ reference ]
 
 			samples.append((
@@ -383,16 +390,16 @@ def main():
 				if should_generate:
 					comparison_inputs.append((text, prompt, language, out_path_comparison))
 
-				metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path))
+				metrics_inputs.append((dataset_name, text, language, out_path_comparison, prompt, reference, metrics_path))
 
 			should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
 
 			if should_generate:
 				inputs.append((text, prompt, language, out_path))
 
-			metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path))
+			metrics_inputs.append((dataset_name, text, language, out_path, prompt, reference, metrics_path))
 
-		outputs.append((k, samples))
+		outputs.append((dataset_name, samples))
 
 	if inputs:
 		process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )
@@ -401,28 +408,32 @@ def main():
 		process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
 
 	metrics_map = {}
 
-	for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
+	for dataset_name, text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
 		calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
 		if calculate:
 			wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
 			sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
 			#sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
-			metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score}
+			metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
 			json_write( metrics, metrics_path )
 		else:
 			metrics = json_read( metrics_path )
 			wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"]
 
-		metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
+		if dataset_name not in metrics_map:
+			metrics_map[dataset_name] = {}
+		metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score)
 
 	# collate entries into HTML
-	for k, samples in outputs:
+	tables = []
+	for dataset_name, samples in outputs:
+		table = "\t\t<h3>${DATASET_NAME}</h3>\n\t\t<p><b>Average WER:</b> ${WER}<br><b>Average CER:</b> ${CER}<br><b>Average SIM-O:</b> ${SIM-O}<br></p>\n\t\t<table>\n\t\t\t<thead>\n\t\t\t\t<tr>\n\t\t\t\t\t<th>Text</th>\n\t\t\t\t\t<th>WER↓</th>\n\t\t\t\t\t<th>CER↓</th>\n\t\t\t\t\t<th>SIM-O↑</th>\n\t\t\t\t\t<th>Prompt</th>\n\t\t\t\t\t<th>Our VALL-E</th>\n\t\t\t\t\t<!--th>Original VALL-E</th-->\n\t\t\t\t\t<!--th>F5-TTS</th-->\n\t\t\t\t\t<th>Ground Truth</th>\n\t\t\t\t</tr>\n\t\t\t</thead>\n\t\t\t<tbody>${SAMPLES}</tbody>\n\t\t</table>"
 		samples = [
 			f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
 			"".join([
-				f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td><td>{metrics_map[audios[1]][2]:.3f}</td>'
+				f'\n\t\t\t\t<td>{metrics_map[dataset_name][audios[1]][0]:.3f}</td><td>{metrics_map[dataset_name][audios[1]][1]:.3f}</td><td>{metrics_map[dataset_name][audios[1]][2]:.3f}</td>'
 			] ) +
 			"".join( [
 				f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
@@ -433,11 +444,15 @@ def main():
 		]
 
 		# write audio into template
-		html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
+		table = table.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
+		table = table.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
+		table = table.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
+		table = table.replace("${DATASET_NAME}", dataset_name)
+		table = table.replace("${SAMPLES}", "\n".join( samples ) )
+		tables.append( table )
 
-	html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' )
-	html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' )
-	html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' )
+	html = html.replace("${TABLES}", "\n".join( tables ))
 
 	if args.comparison:
 		disabled, enabled = comparison_kwargs["titles"]
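Since metrics_map is now nested per dataset, each table's averages come from only that dataset's files instead of one global mean over everything. A minimal sketch of the aggregation, assuming statistics.mean for the mean() used above; the numbers are made up:

# minimal sketch of per-dataset metric averaging; values are made up
from statistics import mean

metrics_map = {
    "librispeech": {
        "out/a.wav": (0.05, 0.02, 0.81),   # (wer, cer, sim-o)
        "out/b.wav": (0.07, 0.03, 0.79),
    },
}

for dataset_name, per_file in metrics_map.items():
    avg_wer   = mean(m[0] for m in per_file.values())
    avg_cer   = mean(m[1] for m in per_file.values())
    avg_sim_o = mean(m[2] for m in per_file.values())
    print(f"{dataset_name}: WER {avg_wer:.3f} CER {avg_cer:.3f} SIM-O {avg_sim_o:.3f}")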

View File

@@ -11,6 +11,8 @@ from torch import Tensor
 from einops import rearrange
 from pathlib import Path
+
+from tqdm import tqdm, trange
 
 from .emb import g2p, qnt
 from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
 from .emb.transcribe import transcribe
@@ -213,7 +215,7 @@ class TTS():
 		input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
 		modality = sampling_kwargs.pop("modality", "auto")
 		seed = sampling_kwargs.pop("seed", None)
-		tqdm = sampling_kwargs.pop("tqdm", True)
+		use_tqdm = sampling_kwargs.pop("tqdm", True)
 		use_lora = sampling_kwargs.pop("use_lora", None)
 		dtype = sampling_kwargs.pop("dtype", self.dtype)
 		amp = sampling_kwargs.pop("amp", self.amp)
@@ -256,7 +258,7 @@
 		inputs = []
 
 		# tensorfy inputs
-		for i in range( samples ):
+		for i in trange( samples, desc="Preparing batches" ):
 			# detect language
 			if languages[i] == "auto":
 				languages[i] = g2p.detect_language( texts[i] )
@@ -295,14 +297,14 @@
 		buffer = ([], [], [], [])
 		wavs = []
 
-		for texts, proms, langs, out_paths in batches:
+		for texts, proms, langs, out_paths in tqdm(batches, desc="Processing batch"):
 			seed = set_seed(seed)
 			batch_size = len(texts)
 
 			input_kwargs = dict(
 				text_list=texts,
 				proms_list=proms,
 				lang_list=langs,
-				disable_tqdm=not tqdm,
+				disable_tqdm=not use_tqdm,
 				use_lora=use_lora,
 			)
@@ -355,7 +357,7 @@
 		input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
 		modality = sampling_kwargs.pop("modality", "auto")
 		seed = sampling_kwargs.pop("seed", None)
-		tqdm = sampling_kwargs.pop("tqdm", True)
+		use_tqdm = sampling_kwargs.pop("tqdm", True)
 		use_lora = sampling_kwargs.pop("use_lora", None)
 		dtype = sampling_kwargs.pop("dtype", self.dtype)
 		amp = sampling_kwargs.pop("amp", self.amp)
@@ -405,7 +407,7 @@
 			if model is not None:
 				text_list = model(
 					text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
-					disable_tqdm=not tqdm,
+					disable_tqdm=not use_tqdm,
 					use_lora=use_lora,
 					**sampling_kwargs,
 				)
@@ -452,7 +454,7 @@
 					text_list=[phns],
 					proms_list=[prom],
 					lang_list=[lang],
-					disable_tqdm=not tqdm,
+					disable_tqdm=not use_tqdm,
 					use_lora=use_lora,
 				)
 			if model_len is not None:
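The tqdm to use_tqdm rename isn't cosmetic: assigning to a local named tqdm shadows the import for the entire function body, so the newly added tqdm(batches, ...) and trange(...) calls would otherwise try to call a popped bool. A small repro of the pitfall:

# small repro of the shadowing pitfall the rename avoids
from tqdm import tqdm

def broken(**sampling_kwargs):
    tqdm = sampling_kwargs.pop("tqdm", True)   # local 'tqdm' shadows the import
    for _ in tqdm(range(3)):                   # TypeError: 'bool' object is not callable
        pass

def fixed(**sampling_kwargs):
    use_tqdm = sampling_kwargs.pop("tqdm", True)
    for _ in tqdm(range(3), disable=not use_tqdm):  # imported tqdm is still reachable
        pass

fixed()      # works
# broken()   # raises TypeError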