From 7129582303698071cdb1feb79d8399606d5f0a7c Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 17 Dec 2024 14:22:30 -0600 Subject: [PATCH] actually do proper wer/cer calculation by un-normalizing the scores --- vall_e/demo.py | 5 +++-- vall_e/metrics.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vall_e/demo.py b/vall_e/demo.py index 8da990b..96f9551 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -279,7 +279,7 @@ def main(): # pull from provided samples samples_dirs = { - #"librispeech": args.demo_dir / "librispeech", + "librispeech": args.demo_dir / "librispeech", } if (args.demo_dir / args.dataset_dir_name).exists(): @@ -407,8 +407,9 @@ def main(): if calculate: wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model ) sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) + #sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) - metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} + metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score} json_write( metrics, metrics_path ) else: metrics = json_read( metrics_path ) diff --git a/vall_e/metrics.py b/vall_e/metrics.py index c38107d..8e91e0f 100644 --- a/vall_e/metrics.py +++ b/vall_e/metrics.py @@ -10,12 +10,7 @@ import torch.nn.functional as F from pathlib import Path from torcheval.metrics.functional import word_error_rate - -# cringe warning message -try: - from torchmetrics.text import CharErrorRate -except Exception as e: - from torchmetrics import CharErrorRate +from torchmetrics.functional.text import char_error_rate def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ): if language == "auto": @@ -43,7 +38,13 @@ def wer( audio, reference, language="auto", normalize=True, phonemize=True, **tr reference = encode( reference, language=language ) wer_score = word_error_rate([transcription], [reference]).item() - cer_score = CharErrorRate()([transcription], [reference]).item() + # un-normalize + wer_score *= len(reference.split()) + + cer_score = char_error_rate([transcription], [reference]).item() + # un-normalize + cer_score *= len(reference) + return wer_score, cer_score def sim_o( audio, reference, **kwargs ):