tweaks to prompt duration to allow me to divorce how i use it for training with how I'm using it for the demo page, and demo page tweaks to make my life easier
This commit is contained in:
parent
7129582303
commit
ed152f78df
|
@ -5,39 +5,7 @@
|
|||
<body>
|
||||
<h1>VALL-E Demo</h1>
|
||||
<p>${PREAMBLE}</p>
|
||||
<table>
|
||||
<thead>
|
||||
<caption>LibriSpeech</caption>
|
||||
<tr>
|
||||
<th>Text</th>
|
||||
<th>WER↓</th>
|
||||
<th>CER↓</th>
|
||||
<th>SIM-O↑</th>
|
||||
<th>Prompt</th>
|
||||
<th>Our VALL-E</th>
|
||||
<th>Original VALL-E</th>
|
||||
<!--th>F5-TTS</th-->
|
||||
<th>Ground Truth</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>${LIBRISPEECH_SAMPLES}</tbody>
|
||||
</table>
|
||||
<table>
|
||||
<thead>
|
||||
<caption>Sampled Dataset</caption>
|
||||
<tr>
|
||||
<th>Text</th>
|
||||
<th>WER↓</th>
|
||||
<th>CER↓</th>
|
||||
<th>SIM-O↑</th>
|
||||
<th>Prompt</th>
|
||||
<th>Our VALL-E</th>
|
||||
<!--th>F5-TTS</th-->
|
||||
<th>Ground Truth</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>${DATASET_SAMPLES}</tbody>
|
||||
</table>
|
||||
${TABLES}
|
||||
<p>Settings used: <pre>${SETTINGS}</pre></p>
|
||||
</body>
|
||||
</html>
|
|
@ -1058,7 +1058,7 @@ class Dataset(_Dataset):
|
|||
def sample_prompts(self, spkr_name, reference, should_trim=True):
|
||||
# return no prompt if explicitly requested for who knows why
|
||||
# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
|
||||
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[spkr_name]) <= 1:
|
||||
if len(self.paths_by_spkr_name[spkr_name]) <= 1:
|
||||
return None
|
||||
|
||||
prom_list = []
|
||||
|
@ -1075,9 +1075,15 @@ class Dataset(_Dataset):
|
|||
)
|
||||
"""
|
||||
|
||||
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
|
||||
should_trim = False
|
||||
|
||||
prom_length = 0
|
||||
duration_lo, duration_hi = cfg.dataset.prompt_duration_range
|
||||
trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0
|
||||
if should_trim:
|
||||
duration_lo, duration_hi = cfg.dataset.prompt_duration_range
|
||||
trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)
|
||||
else:
|
||||
trim_length = 0
|
||||
|
||||
for _ in range(cfg.dataset.prompt_max_samples):
|
||||
if reference is not None:
|
||||
|
|
|
@ -278,16 +278,24 @@ def main():
|
|||
html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
|
||||
|
||||
# pull from provided samples
|
||||
samples_dirs = {
|
||||
"librispeech": args.demo_dir / "librispeech",
|
||||
}
|
||||
samples_dirs = {}
|
||||
|
||||
if (args.demo_dir / args.dataset_dir_name).exists():
|
||||
# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
|
||||
if "validation" in args.dataset_dir_name:
|
||||
sample_dir["librispeech"] = args.demo_dir / "librispeech",
|
||||
|
||||
# automatically pull from anything under the dataset dir
|
||||
if args.dataset_dir_name.endswith("/*"):
|
||||
args.dataset_dir_name = args.dataset_dir_name[:-2]
|
||||
datasets = [ dir for dir in (args.demo_dir / args.dataset_dir_name).iterdir() if dir.is_dir() ]
|
||||
for path in datasets:
|
||||
samples_dirs[path.name] = path
|
||||
# user provided dataset
|
||||
elif (args.demo_dir / args.dataset_dir_name).exists():
|
||||
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
|
||||
|
||||
# pull from dataset samples
|
||||
if args.sample_from_dataset:
|
||||
cfg.dataset.cache = False
|
||||
cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
|
||||
cfg.dataset.sample_order = "random"
|
||||
cfg.dataset.tasks_list = [ 'tts' ]
|
||||
|
@ -335,18 +343,17 @@ def main():
|
|||
outputs = []
|
||||
metrics_inputs = []
|
||||
comparison_inputs = []
|
||||
for k, sample_dir in samples_dirs.items():
|
||||
for dataset_name, sample_dir in samples_dirs.items():
|
||||
if not sample_dir.exists():
|
||||
continue
|
||||
|
||||
samples = []
|
||||
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
|
||||
speakers.sort()
|
||||
#sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
|
||||
sources = [ "ms_valle" ] if k == "librispeech" else []
|
||||
sources = [ "ms_valle", "f5" ] if dataset_name == "librispeech" else []
|
||||
|
||||
# generate demo output
|
||||
for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
|
||||
for dir in tqdm(speakers, desc=f"Preparing demo for {dataset_name}"):
|
||||
text = open(dir / "prompt.txt", encoding="utf-8").read()
|
||||
language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
|
||||
prompt = dir / "prompt.wav"
|
||||
|
@ -368,7 +375,7 @@ def main():
|
|||
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
|
||||
"""
|
||||
|
||||
if not args.random_prompts or k == "librispeech":
|
||||
if not args.random_prompts or dataset_name == "librispeech":
|
||||
audio_samples += [ reference ]
|
||||
|
||||
samples.append((
|
||||
|
@ -383,16 +390,16 @@ def main():
|
|||
if should_generate:
|
||||
comparison_inputs.append((text, prompt, language, out_path_comparison))
|
||||
|
||||
metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path))
|
||||
metrics_inputs.append((dataset_name, text, language, out_path_comparison, prompt, reference, metrics_path))
|
||||
|
||||
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
|
||||
|
||||
if should_generate:
|
||||
inputs.append((text, prompt, language, out_path))
|
||||
|
||||
metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path))
|
||||
metrics_inputs.append((dataset_name, text, language, out_path, prompt, reference, metrics_path))
|
||||
|
||||
outputs.append((k, samples))
|
||||
outputs.append((dataset_name, samples))
|
||||
|
||||
if inputs:
|
||||
process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )
|
||||
|
@ -401,28 +408,32 @@ def main():
|
|||
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
||||
|
||||
metrics_map = {}
|
||||
for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||
for dataset_name, text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||
calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
|
||||
|
||||
if calculate:
|
||||
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
||||
sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||
#sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||
|
||||
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score}
|
||||
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
|
||||
json_write( metrics, metrics_path )
|
||||
else:
|
||||
metrics = json_read( metrics_path )
|
||||
wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"]
|
||||
|
||||
metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
|
||||
if dataset_name not in metrics_map:
|
||||
metrics_map[dataset_name] = {}
|
||||
|
||||
metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score)
|
||||
|
||||
# collate entries into HTML
|
||||
for k, samples in outputs:
|
||||
tables = []
|
||||
for dataset_name, samples in outputs:
|
||||
table = "\t\t<h3>${DATASET_NAME}</h3>\n\t\t<p><b>Average WER:</b> ${WER}<br><b>Average CER:</b> ${CER}<br><b>Average SIM-O:</b> ${SIM-O}<br></p>\n\t\t<table>\n\t\t\t<thead>\n\t\t\t\t<tr>\n\t\t\t\t\t<th>Text</th>\n\t\t\t\t\t<th>WER↓</th>\n\t\t\t\t\t<th>CER↓</th>\n\t\t\t\t\t<th>SIM-O↑</th>\n\t\t\t\t\t<th>Prompt</th>\n\t\t\t\t\t<th>Our VALL-E</th>\n\t\t\t\t\t<!--th>Original VALL-E</th-->\n\t\t\t\t\t<!--th>F5-TTS</th-->\n\t\t\t\t\t<th>Ground Truth</th>\n\t\t\t\t</tr>\n\t\t\t</thead>\n\t\t\t<tbody>${SAMPLES}</tbody>\n\t\t</table>"
|
||||
samples = [
|
||||
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
||||
"".join([
|
||||
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td><td>{metrics_map[audios[1]][2]:.3f}</td>'
|
||||
f'\n\t\t\t\t<td>{metrics_map[dataset_name][audios[1]][0]:.3f}</td><td>{metrics_map[dataset_name][audios[1]][1]:.3f}</td><td>{metrics_map[dataset_name][audios[1]][2]:.3f}</td>'
|
||||
] ) +
|
||||
"".join( [
|
||||
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
|
||||
|
@ -433,11 +444,15 @@ def main():
|
|||
]
|
||||
|
||||
# write audio into template
|
||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||
table = table.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
|
||||
table = table.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
|
||||
table = table.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
|
||||
|
||||
table = table.replace("${DATASET_NAME}", dataset_name)
|
||||
table = table.replace("${SAMPLES}", "\n".join( samples ) )
|
||||
tables.append( table )
|
||||
|
||||
html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' )
|
||||
html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' )
|
||||
html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' )
|
||||
html = html.replace("${TABLES}", "\n".join( tables ))
|
||||
|
||||
if args.comparison:
|
||||
disabled, enabled = comparison_kwargs["titles"]
|
||||
|
|
|
@ -11,6 +11,8 @@ from torch import Tensor
|
|||
from einops import rearrange
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from .emb import g2p, qnt
|
||||
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
|
||||
from .emb.transcribe import transcribe
|
||||
|
@ -213,7 +215,7 @@ class TTS():
|
|||
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
|
||||
modality = sampling_kwargs.pop("modality", "auto")
|
||||
seed = sampling_kwargs.pop("seed", None)
|
||||
tqdm = sampling_kwargs.pop("tqdm", True)
|
||||
use_tqdm = sampling_kwargs.pop("tqdm", True)
|
||||
use_lora = sampling_kwargs.pop("use_lora", None)
|
||||
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
||||
amp = sampling_kwargs.pop("amp", self.amp)
|
||||
|
@ -256,7 +258,7 @@ class TTS():
|
|||
|
||||
inputs = []
|
||||
# tensorfy inputs
|
||||
for i in range( samples ):
|
||||
for i in trange( samples, desc="Preparing batches" ):
|
||||
# detect language
|
||||
if languages[i] == "auto":
|
||||
languages[i] = g2p.detect_language( texts[i] )
|
||||
|
@ -295,14 +297,14 @@ class TTS():
|
|||
buffer = ([], [], [], [])
|
||||
|
||||
wavs = []
|
||||
for texts, proms, langs, out_paths in batches:
|
||||
for texts, proms, langs, out_paths in tqdm(batches, desc="Processing batch"):
|
||||
seed = set_seed(seed)
|
||||
batch_size = len(texts)
|
||||
input_kwargs = dict(
|
||||
text_list=texts,
|
||||
proms_list=proms,
|
||||
lang_list=langs,
|
||||
disable_tqdm=not tqdm,
|
||||
disable_tqdm=not use_tqdm,
|
||||
use_lora=use_lora,
|
||||
)
|
||||
|
||||
|
@ -355,7 +357,7 @@ class TTS():
|
|||
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
|
||||
modality = sampling_kwargs.pop("modality", "auto")
|
||||
seed = sampling_kwargs.pop("seed", None)
|
||||
tqdm = sampling_kwargs.pop("tqdm", True)
|
||||
use_tqdm = sampling_kwargs.pop("tqdm", True)
|
||||
use_lora = sampling_kwargs.pop("use_lora", None)
|
||||
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
||||
amp = sampling_kwargs.pop("amp", self.amp)
|
||||
|
@ -405,7 +407,7 @@ class TTS():
|
|||
if model is not None:
|
||||
text_list = model(
|
||||
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
|
||||
disable_tqdm=not tqdm,
|
||||
disable_tqdm=not use_tqdm,
|
||||
use_lora=use_lora,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
|
@ -452,7 +454,7 @@ class TTS():
|
|||
text_list=[phns],
|
||||
proms_list=[prom],
|
||||
lang_list=[lang],
|
||||
disable_tqdm=not tqdm,
|
||||
disable_tqdm=not use_tqdm,
|
||||
use_lora=use_lora,
|
||||
)
|
||||
if model_len is not None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user