tweaks to prompt duration to let me divorce how I use it for training from how I'm using it for the demo page, and demo page tweaks to make my life easier

This commit is contained in:
mrq 2024-12-17 19:33:04 -06:00
parent 7129582303
commit ed152f78df
4 changed files with 57 additions and 66 deletions

View File

@@ -5,39 +5,7 @@
 <body>
 	<h1>VALL-E Demo</h1>
 	<p>${PREAMBLE}</p>
-	<table>
-		<thead>
-			<caption>LibriSpeech</caption>
-			<tr>
-				<th>Text</th>
-				<th>WER↓</th>
-				<th>CER↓</th>
-				<th>SIM-O↑</th>
-				<th>Prompt</th>
-				<th>Our VALL-E</th>
-				<th>Original VALL-E</th>
-				<!--th>F5-TTS</th-->
-				<th>Ground Truth</th>
-			</tr>
-		</thead>
-		<tbody>${LIBRISPEECH_SAMPLES}</tbody>
-	</table>
-	<table>
-		<thead>
-			<caption>Sampled Dataset</caption>
-			<tr>
-				<th>Text</th>
-				<th>WER↓</th>
-				<th>CER↓</th>
-				<th>SIM-O↑</th>
-				<th>Prompt</th>
-				<th>Our VALL-E</th>
-				<!--th>F5-TTS</th-->
-				<th>Ground Truth</th>
-			</tr>
-		</thead>
-		<tbody>${DATASET_SAMPLES}</tbody>
-	</table>
+	${TABLES}
 	<p>Settings used: <pre>${SETTINGS}</pre></p>
 </body>
 </html>

View File

@@ -1058,7 +1058,7 @@ class Dataset(_Dataset):
 	def sample_prompts(self, spkr_name, reference, should_trim=True):
 		# return no prompt if explicitly requested for who knows why
 		# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
-		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[spkr_name]) <= 1:
+		if len(self.paths_by_spkr_name[spkr_name]) <= 1:
 			return None
 
 		prom_list = []
@@ -1075,9 +1075,15 @@ class Dataset(_Dataset):
 		)
 		"""
+		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
+			should_trim = False
 
 		prom_length = 0
-		duration_lo, duration_hi = cfg.dataset.prompt_duration_range
-		trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0
+		if should_trim:
+			duration_lo, duration_hi = cfg.dataset.prompt_duration_range
+			trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)
+		else:
+			trim_length = 0
 
 		for _ in range(cfg.dataset.prompt_max_samples):
 			if reference is not None:
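The net effect of this change: an unset or non-positive prompt_duration_range now disables trimming instead of suppressing the prompt entirely, so a training config and the demo page can share a dataset definition without edits. A minimal standalone sketch of the new gating (hypothetical helper, assuming prompt_duration_range is a (lo, hi) pair in seconds and frames_per_second comes from cfg as in the diff):

	import random

	def compute_trim_length(prompt_duration_range, frames_per_second, should_trim=True):
		# an empty or non-positive range now disables trimming rather than
		# returning no prompt at all
		if not prompt_duration_range or prompt_duration_range[1] <= 0:
			should_trim = False
		if should_trim:
			duration_lo, duration_hi = prompt_duration_range
			return int(random.uniform(duration_lo, duration_hi) * frames_per_second)
		return 0

	# compute_trim_length((3.0, 6.0), 75) yields roughly 225-450 frames,
	# while compute_trim_length(None, 75) or a (0, 0) range yields 0 (no trim)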

View File

@@ -278,16 +278,24 @@ def main():
 	html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
 
 	# pull from provided samples
-	samples_dirs = {
-		"librispeech": args.demo_dir / "librispeech",
-	}
+	samples_dirs = {}
 
-	if (args.demo_dir / args.dataset_dir_name).exists():
+	# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
+	if "validation" in args.dataset_dir_name:
+		samples_dirs["librispeech"] = args.demo_dir / "librispeech"
+
+	# automatically pull from anything under the dataset dir
+	if args.dataset_dir_name.endswith("/*"):
+		args.dataset_dir_name = args.dataset_dir_name[:-2]
+		datasets = [ dir for dir in (args.demo_dir / args.dataset_dir_name).iterdir() if dir.is_dir() ]
+		for path in datasets:
+			samples_dirs[path.name] = path
+	# user provided dataset
+	elif (args.demo_dir / args.dataset_dir_name).exists():
 		samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
 
 	# pull from dataset samples
 	if args.sample_from_dataset:
+		cfg.dataset.cache = False
 		cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
 		cfg.dataset.sample_order = "random"
 		cfg.dataset.tasks_list = [ 'tts' ]
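For reference, the new directory resolution in one place: librispeech is only re-added when the dataset dir name contains "validation", and a trailing "/*" fans out every subdirectory into its own dataset. A condensed sketch under those assumptions (a hypothetical helper, not the script itself; args fields per the diff):

	from pathlib import Path

	def resolve_samples_dirs(demo_dir: Path, dataset_dir_name: str) -> dict:
		samples_dirs = {}
		if "validation" in dataset_dir_name:
			samples_dirs["librispeech"] = demo_dir / "librispeech"
		if dataset_dir_name.endswith("/*"):
			# fan out: every subdirectory under the dataset dir becomes a dataset
			root = demo_dir / dataset_dir_name[:-2]
			for path in root.iterdir():
				if path.is_dir():
					samples_dirs[path.name] = path
		elif (demo_dir / dataset_dir_name).exists():
			# user-provided dataset keeps the old single "dataset" key
			samples_dirs["dataset"] = demo_dir / dataset_dir_name
		return samples_dirs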
@@ -335,18 +343,17 @@ def main():
 	outputs = []
 	metrics_inputs = []
 	comparison_inputs = []
-	for k, sample_dir in samples_dirs.items():
+	for dataset_name, sample_dir in samples_dirs.items():
 		if not sample_dir.exists():
 			continue
 		samples = []
 		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
 		speakers.sort()
-		#sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
-		sources = [ "ms_valle" ] if k == "librispeech" else []
+		sources = [ "ms_valle", "f5" ] if dataset_name == "librispeech" else []
 
 		# generate demo output
-		for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
+		for dir in tqdm(speakers, desc=f"Preparing demo for {dataset_name}"):
 			text = open(dir / "prompt.txt", encoding="utf-8").read()
 			language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
 			prompt = dir / "prompt.wav"
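Each speaker directory the loop walks is expected to hold a transcript, an optional language code, and a prompt wav. A sketch of that contract (file names per the diff; the exact layout is otherwise an assumption):

	from pathlib import Path

	def read_speaker_dir(speaker_dir: Path):
		# prompt.txt: transcript to synthesize
		# language.txt: optional language code, defaults to "en"
		# prompt.wav: reference audio used as the acoustic prompt
		text = (speaker_dir / "prompt.txt").read_text(encoding="utf-8")
		lang_file = speaker_dir / "language.txt"
		language = lang_file.read_text() if lang_file.exists() else "en"
		return text, language, speaker_dir / "prompt.wav"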
@@ -368,7 +375,7 @@ def main():
 			cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
 			"""
 
-			if not args.random_prompts or k == "librispeech":
+			if not args.random_prompts or dataset_name == "librispeech":
 				audio_samples += [ reference ]
 
 			samples.append((
@@ -383,16 +390,16 @@ def main():
 			if should_generate:
 				comparison_inputs.append((text, prompt, language, out_path_comparison))
-				metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path))
+				metrics_inputs.append((dataset_name, text, language, out_path_comparison, prompt, reference, metrics_path))
 
 			should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
 			if should_generate:
 				inputs.append((text, prompt, language, out_path))
-				metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path))
+				metrics_inputs.append((dataset_name, text, language, out_path, prompt, reference, metrics_path))
 
-		outputs.append((k, samples))
+		outputs.append((dataset_name, samples))
 
 	if inputs:
 		process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )
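The skip condition above, (args.skip_existing and not out_path.exists()) or not (args.skip_existing), reduces to "regenerate unless skipping is requested and the output already exists"; as a sketch:

	def should_generate(skip_existing: bool, out_path) -> bool:
		# equivalent to (skip_existing and not out_path.exists()) or not skip_existing
		return not (skip_existing and out_path.exists())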
@@ -401,28 +408,32 @@ def main():
 		process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
 
 	metrics_map = {}
-	for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
+	for dataset_name, text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
 		calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
 
 		if calculate:
 			wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
 			sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
-			#sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
-			metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score}
+			metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
 			json_write( metrics, metrics_path )
 		else:
 			metrics = json_read( metrics_path )
 			wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"]
 
-		metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
+		if dataset_name not in metrics_map:
+			metrics_map[dataset_name] = {}
+		metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score)
 
 	# collate entries into HTML
-	for k, samples in outputs:
+	tables = []
+	for dataset_name, samples in outputs:
+		table = "\t\t<h3>${DATASET_NAME}</h3>\n\t\t<p><b>Average WER:</b> ${WER}<br><b>Average CER:</b> ${CER}<br><b>Average SIM-O:</b> ${SIM-O}<br></p>\n\t\t<table>\n\t\t\t<thead>\n\t\t\t\t<tr>\n\t\t\t\t\t<th>Text</th>\n\t\t\t\t\t<th>WER↓</th>\n\t\t\t\t\t<th>CER↓</th>\n\t\t\t\t\t<th>SIM-O↑</th>\n\t\t\t\t\t<th>Prompt</th>\n\t\t\t\t\t<th>Our VALL-E</th>\n\t\t\t\t\t<!--th>Original VALL-E</th-->\n\t\t\t\t\t<!--th>F5-TTS</th-->\n\t\t\t\t\t<th>Ground Truth</th>\n\t\t\t\t</tr>\n\t\t\t</thead>\n\t\t\t<tbody>${SAMPLES}</tbody>\n\t\t</table>"
 		samples = [
 			f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
 			"".join([
-				f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td><td>{metrics_map[audios[1]][2]:.3f}</td>'
+				f'\n\t\t\t\t<td>{metrics_map[dataset_name][audios[1]][0]:.3f}</td><td>{metrics_map[dataset_name][audios[1]][1]:.3f}</td><td>{metrics_map[dataset_name][audios[1]][2]:.3f}</td>'
 			] ) +
 			"".join( [
 				f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
@@ -433,11 +444,15 @@ def main():
 		]
 
 		# write audio into template
-		html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
+		table = table.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
+		table = table.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
+		table = table.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
+		table = table.replace("${DATASET_NAME}", dataset_name)
+		table = table.replace("${SAMPLES}", "\n".join( samples ) )
+		tables.append( table )
 
-	html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' )
-	html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' )
-	html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' )
+	html = html.replace("${TABLES}", "\n".join( tables ))
 
 	if args.comparison:
 		disabled, enabled = comparison_kwargs["titles"]
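metrics_map is now nested per dataset, {dataset_name: {out_path: (wer, cer, sim_o)}}, so each table header carries its own averages instead of one global set. A sketch of the averaging that fills ${WER}/${CER}/${SIM-O}, assuming mean comes from the statistics module (the diff doesn't show the import):

	from statistics import mean

	def table_averages(metrics_map: dict, dataset_name: str) -> dict:
		# one (wer, cer, sim_o) tuple per generated file for this dataset
		scores = list(metrics_map[dataset_name].values())
		return {
			"${WER}": f"{mean(s[0] for s in scores):.3f}",
			"${CER}": f"{mean(s[1] for s in scores):.3f}",
			"${SIM-O}": f"{mean(s[2] for s in scores):.3f}",
		}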

View File

@@ -11,6 +11,8 @@ from torch import Tensor
 from einops import rearrange
 from pathlib import Path
 
+from tqdm import tqdm, trange
+
 from .emb import g2p, qnt
 from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
 from .emb.transcribe import transcribe
@@ -213,7 +215,7 @@ class TTS():
 		input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
 		modality = sampling_kwargs.pop("modality", "auto")
 		seed = sampling_kwargs.pop("seed", None)
-		tqdm = sampling_kwargs.pop("tqdm", True)
+		use_tqdm = sampling_kwargs.pop("tqdm", True)
 		use_lora = sampling_kwargs.pop("use_lora", None)
 		dtype = sampling_kwargs.pop("dtype", self.dtype)
 		amp = sampling_kwargs.pop("amp", self.amp)
@@ -256,7 +258,7 @@ class TTS():
 		inputs = []
 		# tensorfy inputs
-		for i in range( samples ):
+		for i in trange( samples, desc="Preparing batches" ):
 			# detect language
 			if languages[i] == "auto":
 				languages[i] = g2p.detect_language( texts[i] )
@@ -295,14 +297,14 @@ class TTS():
 		buffer = ([], [], [], [])
 		wavs = []
 
-		for texts, proms, langs, out_paths in batches:
+		for texts, proms, langs, out_paths in tqdm(batches, desc="Processing batch"):
 			seed = set_seed(seed)
 			batch_size = len(texts)
 			input_kwargs = dict(
 				text_list=texts,
 				proms_list=proms,
 				lang_list=langs,
-				disable_tqdm=not tqdm,
+				disable_tqdm=not use_tqdm,
 				use_lora=use_lora,
 			)
@@ -355,7 +357,7 @@ class TTS():
 		input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
 		modality = sampling_kwargs.pop("modality", "auto")
 		seed = sampling_kwargs.pop("seed", None)
-		tqdm = sampling_kwargs.pop("tqdm", True)
+		use_tqdm = sampling_kwargs.pop("tqdm", True)
 		use_lora = sampling_kwargs.pop("use_lora", None)
 		dtype = sampling_kwargs.pop("dtype", self.dtype)
 		amp = sampling_kwargs.pop("amp", self.amp)
@@ -405,7 +407,7 @@ class TTS():
 		if model is not None:
 			text_list = model(
 				text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
-				disable_tqdm=not tqdm,
+				disable_tqdm=not use_tqdm,
 				use_lora=use_lora,
 				**sampling_kwargs,
 			)
@@ -452,7 +454,7 @@ class TTS():
 				text_list=[phns],
 				proms_list=[prom],
 				lang_list=[lang],
-				disable_tqdm=not tqdm,
+				disable_tqdm=not use_tqdm,
 				use_lora=use_lora,
 			)
 			if model_len is not None:
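Why the tqdm to use_tqdm rename matters: these methods now call tqdm()/trange() directly, and popping the flag into a local named tqdm shadows the import for the whole function body, so tqdm(batches, ...) would call a bool. A minimal reproduction of the pitfall:

	from tqdm import tqdm

	def broken(sampling_kwargs):
		tqdm = sampling_kwargs.pop("tqdm", True)  # local assignment shadows the import
		return tqdm(range(3))                     # TypeError: 'bool' object is not callable

	def fixed(sampling_kwargs):
		use_tqdm = sampling_kwargs.pop("tqdm", True)
		return tqdm(range(3), disable=not use_tqdm)  # the real progress bar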