tweaks to prompt duration to allow me to divorce how I use it for training from how I'm using it for the demo page, and demo page tweaks to make my life easier
This commit is contained in:
parent
7129582303
commit
ed152f78df
|
@ -5,39 +5,7 @@
|
||||||
<body>
|
<body>
|
||||||
<h1>VALL-E Demo</h1>
|
<h1>VALL-E Demo</h1>
|
||||||
<p>${PREAMBLE}</p>
|
<p>${PREAMBLE}</p>
|
||||||
<table>
|
${TABLES}
|
||||||
<thead>
|
|
||||||
<caption>LibriSpeech</caption>
|
|
||||||
<tr>
|
|
||||||
<th>Text</th>
|
|
||||||
<th>WER↓</th>
|
|
||||||
<th>CER↓</th>
|
|
||||||
<th>SIM-O↑</th>
|
|
||||||
<th>Prompt</th>
|
|
||||||
<th>Our VALL-E</th>
|
|
||||||
<th>Original VALL-E</th>
|
|
||||||
<!--th>F5-TTS</th-->
|
|
||||||
<th>Ground Truth</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>${LIBRISPEECH_SAMPLES}</tbody>
|
|
||||||
</table>
|
|
||||||
<table>
|
|
||||||
<thead>
|
|
||||||
<caption>Sampled Dataset</caption>
|
|
||||||
<tr>
|
|
||||||
<th>Text</th>
|
|
||||||
<th>WER↓</th>
|
|
||||||
<th>CER↓</th>
|
|
||||||
<th>SIM-O↑</th>
|
|
||||||
<th>Prompt</th>
|
|
||||||
<th>Our VALL-E</th>
|
|
||||||
<!--th>F5-TTS</th-->
|
|
||||||
<th>Ground Truth</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>${DATASET_SAMPLES}</tbody>
|
|
||||||
</table>
|
|
||||||
<p>Settings used: <pre>${SETTINGS}</pre></p>
|
<p>Settings used: <pre>${SETTINGS}</pre></p>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
|
@ -1058,7 +1058,7 @@ class Dataset(_Dataset):
|
||||||
def sample_prompts(self, spkr_name, reference, should_trim=True):
|
def sample_prompts(self, spkr_name, reference, should_trim=True):
|
||||||
# return no prompt if explicitly requested for who knows why
|
# return no prompt if explicitly requested for who knows why
|
||||||
# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
|
# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
|
||||||
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[-1] == 0 or len(self.paths_by_spkr_name[spkr_name]) <= 1:
|
if len(self.paths_by_spkr_name[spkr_name]) <= 1:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
prom_list = []
|
prom_list = []
|
||||||
|
@ -1075,9 +1075,15 @@ class Dataset(_Dataset):
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
|
||||||
|
should_trim = False
|
||||||
|
|
||||||
prom_length = 0
|
prom_length = 0
|
||||||
|
if should_trim:
|
||||||
duration_lo, duration_hi = cfg.dataset.prompt_duration_range
|
duration_lo, duration_hi = cfg.dataset.prompt_duration_range
|
||||||
trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0
|
trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)
|
||||||
|
else:
|
||||||
|
trim_length = 0
|
||||||
|
|
||||||
for _ in range(cfg.dataset.prompt_max_samples):
|
for _ in range(cfg.dataset.prompt_max_samples):
|
||||||
if reference is not None:
|
if reference is not None:
|
||||||
|
|
|
@ -278,16 +278,24 @@ def main():
|
||||||
html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
|
html = html.replace(r"${SETTINGS}", str(sampling_kwargs))
|
||||||
|
|
||||||
# pull from provided samples
|
# pull from provided samples
|
||||||
samples_dirs = {
|
samples_dirs = {}
|
||||||
"librispeech": args.demo_dir / "librispeech",
|
|
||||||
}
|
|
||||||
|
|
||||||
if (args.demo_dir / args.dataset_dir_name).exists():
|
# only add the existing librispeech validation dataset if I'm doing validation so I can stop commenting this out
|
||||||
|
if "validation" in args.dataset_dir_name:
|
||||||
|
sample_dir["librispeech"] = args.demo_dir / "librispeech",
|
||||||
|
|
||||||
|
# automatically pull from anything under the dataset dir
|
||||||
|
if args.dataset_dir_name.endswith("/*"):
|
||||||
|
args.dataset_dir_name = args.dataset_dir_name[:-2]
|
||||||
|
datasets = [ dir for dir in (args.demo_dir / args.dataset_dir_name).iterdir() if dir.is_dir() ]
|
||||||
|
for path in datasets:
|
||||||
|
samples_dirs[path.name] = path
|
||||||
|
# user provided dataset
|
||||||
|
elif (args.demo_dir / args.dataset_dir_name).exists():
|
||||||
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
|
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
|
||||||
|
|
||||||
# pull from dataset samples
|
# pull from dataset samples
|
||||||
if args.sample_from_dataset:
|
if args.sample_from_dataset:
|
||||||
cfg.dataset.cache = False
|
|
||||||
cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
|
cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
|
||||||
cfg.dataset.sample_order = "random"
|
cfg.dataset.sample_order = "random"
|
||||||
cfg.dataset.tasks_list = [ 'tts' ]
|
cfg.dataset.tasks_list = [ 'tts' ]
|
||||||
|
@ -335,18 +343,17 @@ def main():
|
||||||
outputs = []
|
outputs = []
|
||||||
metrics_inputs = []
|
metrics_inputs = []
|
||||||
comparison_inputs = []
|
comparison_inputs = []
|
||||||
for k, sample_dir in samples_dirs.items():
|
for dataset_name, sample_dir in samples_dirs.items():
|
||||||
if not sample_dir.exists():
|
if not sample_dir.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
samples = []
|
samples = []
|
||||||
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
|
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
|
||||||
speakers.sort()
|
speakers.sort()
|
||||||
#sources = [ "ms_valle", "f5" ] if k == "librispeech" else ["f5"]
|
sources = [ "ms_valle", "f5" ] if dataset_name == "librispeech" else []
|
||||||
sources = [ "ms_valle" ] if k == "librispeech" else []
|
|
||||||
|
|
||||||
# generate demo output
|
# generate demo output
|
||||||
for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
|
for dir in tqdm(speakers, desc=f"Preparing demo for {dataset_name}"):
|
||||||
text = open(dir / "prompt.txt", encoding="utf-8").read()
|
text = open(dir / "prompt.txt", encoding="utf-8").read()
|
||||||
language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
|
language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
|
||||||
prompt = dir / "prompt.wav"
|
prompt = dir / "prompt.wav"
|
||||||
|
@ -368,7 +375,7 @@ def main():
|
||||||
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
|
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not args.random_prompts or k == "librispeech":
|
if not args.random_prompts or dataset_name == "librispeech":
|
||||||
audio_samples += [ reference ]
|
audio_samples += [ reference ]
|
||||||
|
|
||||||
samples.append((
|
samples.append((
|
||||||
|
@ -383,16 +390,16 @@ def main():
|
||||||
if should_generate:
|
if should_generate:
|
||||||
comparison_inputs.append((text, prompt, language, out_path_comparison))
|
comparison_inputs.append((text, prompt, language, out_path_comparison))
|
||||||
|
|
||||||
metrics_inputs.append((text, language, out_path_comparison, prompt, reference, metrics_path))
|
metrics_inputs.append((dataset_name, text, language, out_path_comparison, prompt, reference, metrics_path))
|
||||||
|
|
||||||
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
|
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
|
||||||
|
|
||||||
if should_generate:
|
if should_generate:
|
||||||
inputs.append((text, prompt, language, out_path))
|
inputs.append((text, prompt, language, out_path))
|
||||||
|
|
||||||
metrics_inputs.append((text, language, out_path, prompt, reference, metrics_path))
|
metrics_inputs.append((dataset_name, text, language, out_path, prompt, reference, metrics_path))
|
||||||
|
|
||||||
outputs.append((k, samples))
|
outputs.append((dataset_name, samples))
|
||||||
|
|
||||||
if inputs:
|
if inputs:
|
||||||
process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )
|
process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )
|
||||||
|
@ -401,28 +408,32 @@ def main():
|
||||||
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
||||||
|
|
||||||
metrics_map = {}
|
metrics_map = {}
|
||||||
for text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
for dataset_name, text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||||
calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
|
calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
|
||||||
|
|
||||||
if calculate:
|
if calculate:
|
||||||
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
||||||
sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||||
#sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
|
||||||
|
|
||||||
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score}
|
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
|
||||||
json_write( metrics, metrics_path )
|
json_write( metrics, metrics_path )
|
||||||
else:
|
else:
|
||||||
metrics = json_read( metrics_path )
|
metrics = json_read( metrics_path )
|
||||||
wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"]
|
wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"]
|
||||||
|
|
||||||
metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
|
if dataset_name not in metrics_map:
|
||||||
|
metrics_map[dataset_name] = {}
|
||||||
|
|
||||||
|
metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score)
|
||||||
|
|
||||||
# collate entries into HTML
|
# collate entries into HTML
|
||||||
for k, samples in outputs:
|
tables = []
|
||||||
|
for dataset_name, samples in outputs:
|
||||||
|
table = "\t\t<h3>${DATASET_NAME}</h3>\n\t\t<p><b>Average WER:</b> ${WER}<br><b>Average CER:</b> ${CER}<br><b>Average SIM-O:</b> ${SIM-O}<br></p>\n\t\t<table>\n\t\t\t<thead>\n\t\t\t\t<tr>\n\t\t\t\t\t<th>Text</th>\n\t\t\t\t\t<th>WER↓</th>\n\t\t\t\t\t<th>CER↓</th>\n\t\t\t\t\t<th>SIM-O↑</th>\n\t\t\t\t\t<th>Prompt</th>\n\t\t\t\t\t<th>Our VALL-E</th>\n\t\t\t\t\t<!--th>Original VALL-E</th-->\n\t\t\t\t\t<!--th>F5-TTS</th-->\n\t\t\t\t\t<th>Ground Truth</th>\n\t\t\t\t</tr>\n\t\t\t</thead>\n\t\t\t<tbody>${SAMPLES}</tbody>\n\t\t</table>"
|
||||||
samples = [
|
samples = [
|
||||||
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
||||||
"".join([
|
"".join([
|
||||||
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td><td>{metrics_map[audios[1]][2]:.3f}</td>'
|
f'\n\t\t\t\t<td>{metrics_map[dataset_name][audios[1]][0]:.3f}</td><td>{metrics_map[dataset_name][audios[1]][1]:.3f}</td><td>{metrics_map[dataset_name][audios[1]][2]:.3f}</td>'
|
||||||
] ) +
|
] ) +
|
||||||
"".join( [
|
"".join( [
|
||||||
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
|
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
|
||||||
|
@ -433,11 +444,15 @@ def main():
|
||||||
]
|
]
|
||||||
|
|
||||||
# write audio into template
|
# write audio into template
|
||||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
table = table.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
|
||||||
|
table = table.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
|
||||||
|
table = table.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map[dataset_name].values() ]):.3f}' )
|
||||||
|
|
||||||
html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' )
|
table = table.replace("${DATASET_NAME}", dataset_name)
|
||||||
html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' )
|
table = table.replace("${SAMPLES}", "\n".join( samples ) )
|
||||||
html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' )
|
tables.append( table )
|
||||||
|
|
||||||
|
html = html.replace("${TABLES}", "\n".join( tables ))
|
||||||
|
|
||||||
if args.comparison:
|
if args.comparison:
|
||||||
disabled, enabled = comparison_kwargs["titles"]
|
disabled, enabled = comparison_kwargs["titles"]
|
||||||
|
|
|
@ -11,6 +11,8 @@ from torch import Tensor
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from .emb import g2p, qnt
|
from .emb import g2p, qnt
|
||||||
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
|
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
|
||||||
from .emb.transcribe import transcribe
|
from .emb.transcribe import transcribe
|
||||||
|
@ -213,7 +215,7 @@ class TTS():
|
||||||
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
|
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
|
||||||
modality = sampling_kwargs.pop("modality", "auto")
|
modality = sampling_kwargs.pop("modality", "auto")
|
||||||
seed = sampling_kwargs.pop("seed", None)
|
seed = sampling_kwargs.pop("seed", None)
|
||||||
tqdm = sampling_kwargs.pop("tqdm", True)
|
use_tqdm = sampling_kwargs.pop("tqdm", True)
|
||||||
use_lora = sampling_kwargs.pop("use_lora", None)
|
use_lora = sampling_kwargs.pop("use_lora", None)
|
||||||
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
||||||
amp = sampling_kwargs.pop("amp", self.amp)
|
amp = sampling_kwargs.pop("amp", self.amp)
|
||||||
|
@ -256,7 +258,7 @@ class TTS():
|
||||||
|
|
||||||
inputs = []
|
inputs = []
|
||||||
# tensorfy inputs
|
# tensorfy inputs
|
||||||
for i in range( samples ):
|
for i in trange( samples, desc="Preparing batches" ):
|
||||||
# detect language
|
# detect language
|
||||||
if languages[i] == "auto":
|
if languages[i] == "auto":
|
||||||
languages[i] = g2p.detect_language( texts[i] )
|
languages[i] = g2p.detect_language( texts[i] )
|
||||||
|
@ -295,14 +297,14 @@ class TTS():
|
||||||
buffer = ([], [], [], [])
|
buffer = ([], [], [], [])
|
||||||
|
|
||||||
wavs = []
|
wavs = []
|
||||||
for texts, proms, langs, out_paths in batches:
|
for texts, proms, langs, out_paths in tqdm(batches, desc="Processing batch"):
|
||||||
seed = set_seed(seed)
|
seed = set_seed(seed)
|
||||||
batch_size = len(texts)
|
batch_size = len(texts)
|
||||||
input_kwargs = dict(
|
input_kwargs = dict(
|
||||||
text_list=texts,
|
text_list=texts,
|
||||||
proms_list=proms,
|
proms_list=proms,
|
||||||
lang_list=langs,
|
lang_list=langs,
|
||||||
disable_tqdm=not tqdm,
|
disable_tqdm=not use_tqdm,
|
||||||
use_lora=use_lora,
|
use_lora=use_lora,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -355,7 +357,7 @@ class TTS():
|
||||||
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
|
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
|
||||||
modality = sampling_kwargs.pop("modality", "auto")
|
modality = sampling_kwargs.pop("modality", "auto")
|
||||||
seed = sampling_kwargs.pop("seed", None)
|
seed = sampling_kwargs.pop("seed", None)
|
||||||
tqdm = sampling_kwargs.pop("tqdm", True)
|
use_tqdm = sampling_kwargs.pop("tqdm", True)
|
||||||
use_lora = sampling_kwargs.pop("use_lora", None)
|
use_lora = sampling_kwargs.pop("use_lora", None)
|
||||||
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
dtype = sampling_kwargs.pop("dtype", self.dtype)
|
||||||
amp = sampling_kwargs.pop("amp", self.amp)
|
amp = sampling_kwargs.pop("amp", self.amp)
|
||||||
|
@ -405,7 +407,7 @@ class TTS():
|
||||||
if model is not None:
|
if model is not None:
|
||||||
text_list = model(
|
text_list = model(
|
||||||
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
|
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
|
||||||
disable_tqdm=not tqdm,
|
disable_tqdm=not use_tqdm,
|
||||||
use_lora=use_lora,
|
use_lora=use_lora,
|
||||||
**sampling_kwargs,
|
**sampling_kwargs,
|
||||||
)
|
)
|
||||||
|
@ -452,7 +454,7 @@ class TTS():
|
||||||
text_list=[phns],
|
text_list=[phns],
|
||||||
proms_list=[prom],
|
proms_list=[prom],
|
||||||
lang_list=[lang],
|
lang_list=[lang],
|
||||||
disable_tqdm=not tqdm,
|
disable_tqdm=not use_tqdm,
|
||||||
use_lora=use_lora,
|
use_lora=use_lora,
|
||||||
)
|
)
|
||||||
if model_len is not None:
|
if model_len is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user