Added CER, transcription/similarity model args in demo

This commit is contained in:
mrq 2024-12-10 21:00:51 -06:00
parent 8568a93dad
commit 6f1ee0c6fa
4 changed files with 59 additions and 12 deletions

View File

@ -35,6 +35,9 @@ from .utils import setup_logging
from tqdm import tqdm, trange
def mean( l ):
return sum(l) / len(l)
def encode(path):
if path is None or not path.exists():
return ""
@ -132,6 +135,9 @@ def main():
parser.add_argument("--lora", action="store_true")
parser.add_argument("--comparison", type=str, default=None)
parser.add_argument("--transcription-model", type=str, default="base")
parser.add_argument("--speaker-similarity-model", type=str, default="wavlm_base_plus")
args = parser.parse_args()
config = None
@ -149,6 +155,10 @@ def main():
args.preamble = "<br>".join([
'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
f'Objective metrics are computed by transcribing ({args.transcription_model}) then comparing the word error rate on transcriptions (WER/CER), and computing the cosine similarities on embeddings through a speaker feature extraction model ({args.speaker_similarity_model}) (SIM-O)',
'<b>Total WER:</b> ${WER}'
'<b>Total CER:</b> ${CER}'
'<b>Total SIM-O:</b> ${SIM-O}'
])
# comparison kwargs
@ -384,17 +394,18 @@ def main():
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
metrics_map = {}
for text, language, out_path, reference_path in metrics_inputs:
wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" )
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype )
metrics_map[out_path] = (wer_score, sim_o_score)
total_metrics = (0, 0)
for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"):
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, feat_type=args.speaker_similarity_model )
metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
# collate entries into HTML
for k, samples in outputs:
samples = [
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
"".join([
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td>'
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td><td>{metrics_map[audios[1]][2]:.3f}</td>'
] ) +
"".join( [
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
@ -406,6 +417,10 @@ def main():
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' )
html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' )
html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' )
if args.comparison:
disabled, enabled = comparison_kwargs["titles"]

View File

@ -46,8 +46,24 @@ tts = None
# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
@cache
def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768):
def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim="auto"):
logging.getLogger('s3prl').setLevel(logging.DEBUG)
logging.getLogger('speechbrain').setLevel(logging.DEBUG)
from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
if feat_dim == "auto":
if feat_type == "fbank":
feat_dim = 40
elif feat_type == "wavlm_base_plus":
feat_dim = 768
elif feat_type == "wavlm_large":
feat_dim = 1024
elif feat_type == "hubert_large_ll60k":
feat_dim = 1024
elif feat_type == "wav2vec2_xlsr":
feat_dim = 1024
model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
model = model.to(device=device, dtype=coerce_dtype(dtype))
model = model.eval()

View File

@ -3,28 +3,43 @@
#from .emb.transcribe import transcribe
from .emb.similar import speaker_similarity_embedding
from .emb.transcribe import transcribe
from .emb.g2p import detect_language
from .emb.g2p import detect_language, coerce_to_hiragana, encode
from .data import normalize_text
import torch.nn.functional as F
from pathlib import Path
from torcheval.metrics.functional import word_error_rate
from torchmetrics import CharErrorRate
def wer( audio, reference, language="auto", **transcription_kwargs ):
def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ):
if language == "auto":
language = detect_language( reference )
transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )["text"]
transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )
if language == "auto":
language = transcription["language"]
transcription = transcription["text"]
# reference audio needs transcribing too
if isinstance( reference, Path ):
reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]
transcription = normalize_text( transcription )
reference = normalize_text( reference )
if language == "ja":
transcription = coerce_to_hiragana( transcription )
reference = coerce_to_hiragana( reference )
return word_error_rate([transcription], [reference]).item()
if normalize:
transcription = normalize_text( transcription )
reference = normalize_text( reference )
if phonemize:
transcription = encode( transcription, language=language )
reference = encode( reference, language=language )
wer_score = word_error_rate([transcription], [reference]).item()
cer_score = CharErrorRate()([transcription], [reference]).item()
return wer_score, cer_score
def sim_o( audio, reference, **kwargs ):
audio_emb = speaker_similarity_embedding( audio, **kwargs )

View File

@ -1,4 +1,5 @@
# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
# (which was from https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py)
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import torch