diff --git a/vall_e/demo.py b/vall_e/demo.py index a050102..13d4160 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -35,6 +35,9 @@ from .utils import setup_logging from tqdm import tqdm, trange +def mean( l ): + return sum(l) / len(l) if l else 0.0 + def encode(path): if path is None or not path.exists(): return "" @@ -132,6 +135,9 @@ def main(): parser.add_argument("--lora", action="store_true") parser.add_argument("--comparison", type=str, default=None) + parser.add_argument("--transcription-model", type=str, default="base") + parser.add_argument("--speaker-similarity-model", type=str, default="wavlm_base_plus") + args = parser.parse_args() config = None @@ -149,6 +155,10 @@ def main(): args.preamble = "
".join([ 'Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/.', 'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.', + f'Objective metrics are computed by transcribing ({args.transcription_model}) then comparing the word error rate on transcriptions (WER/CER), and computing the cosine similarities on embeddings through a speaker feature extraction model ({args.speaker_similarity_model}) (SIM-O)', + 'Total WER: ${WER}', + 'Total CER: ${CER}', + 'Total SIM-O: ${SIM-O}', ]) # comparison kwargs @@ -384,17 +394,18 @@ def main(): process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) ) metrics_map = {} - for text, language, out_path, reference_path in metrics_inputs: - wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" ) - sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype ) - metrics_map[out_path] = (wer_score, sim_o_score) + total_metrics = (0, 0) + for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"): + wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model ) + sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, feat_type=args.speaker_similarity_model ) + metrics_map[out_path] = (wer_score, cer_score, sim_o_score) # collate entries into HTML for k, samples in outputs: samples = [ f'\n\t\t\t\n\t\t\t\t{text}'+ "".join([ - f'\n\t\t\t\t{metrics_map[audios[1]][0]:.3f}{metrics_map[audios[1]][1]:.3f}' + f'\n\t\t\t\t{metrics_map[audios[1]][0]:.3f}{metrics_map[audios[1]][1]:.3f}{metrics_map[audios[1]][2]:.3f}' ] ) + "".join( [ f'\n\t\t\t\t' @@ -406,6 +417,10 @@ def main(): # write audio into template html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) 
) + + html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' ) + html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' ) + html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' ) if args.comparison: disabled, enabled = comparison_kwargs["titles"] diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index 42b16db..8648770 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -46,8 +46,24 @@ tts = None # this is for computing SIM-O, but can probably technically be used for scoring similar utterances @cache -def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768): +def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim="auto"): + logging.getLogger('s3prl').setLevel(logging.DEBUG) + logging.getLogger('speechbrain').setLevel(logging.DEBUG) + from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL + + if feat_dim == "auto": + if feat_type == "fbank": + feat_dim = 40 + elif feat_type == "wavlm_base_plus": + feat_dim = 768 + elif feat_type == "wavlm_large": + feat_dim = 1024 + elif feat_type == "hubert_large_ll60k": + feat_dim = 1024 + elif feat_type == "wav2vec2_xlsr": + feat_dim = 1024 + model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None) model = model.to(device=device, dtype=coerce_dtype(dtype)) model = model.eval() diff --git a/vall_e/metrics.py b/vall_e/metrics.py index 67388b8..032e38b 100644 --- a/vall_e/metrics.py +++ b/vall_e/metrics.py @@ -3,28 +3,43 @@ #from .emb.transcribe import transcribe from .emb.similar import speaker_similarity_embedding from .emb.transcribe import transcribe -from .emb.g2p import detect_language +from .emb.g2p import detect_language, coerce_to_hiragana, encode from .data import normalize_text import torch.nn.functional as F from pathlib import Path from torcheval.metrics.functional import 
word_error_rate +from torchmetrics import CharErrorRate -def wer( audio, reference, language="auto", **transcription_kwargs ): +def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ): if language == "auto": language = detect_language( reference ) - transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )["text"] + transcription = transcribe( audio, language=language, align=False, **transcription_kwargs ) + if language == "auto": + language = transcription["language"] + transcription = transcription["text"] # reference audio needs transcribing too if isinstance( reference, Path ): reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"] - transcription = normalize_text( transcription ) - reference = normalize_text( reference ) + if language == "ja": + transcription = coerce_to_hiragana( transcription ) + reference = coerce_to_hiragana( reference ) - return word_error_rate([transcription], [reference]).item() + if normalize: + transcription = normalize_text( transcription ) + reference = normalize_text( reference ) + + if phonemize: + transcription = encode( transcription, language=language ) + reference = encode( reference, language=language ) + + wer_score = word_error_rate([transcription], [reference]).item() + cer_score = CharErrorRate()([transcription], [reference]).item() + return wer_score, cer_score def sim_o( audio, reference, **kwargs ): audio_emb = speaker_similarity_embedding( audio, **kwargs ) diff --git a/vall_e/utils/ext/ecapa_tdnn.py b/vall_e/utils/ext/ecapa_tdnn.py index 29a99fb..8f28f69 100644 --- a/vall_e/utils/ext/ecapa_tdnn.py +++ b/vall_e/utils/ext/ecapa_tdnn.py @@ -1,4 +1,5 @@ # borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py" +# (which was from 
https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py) # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN import torch