actually do proper wer/cer calculation by un-normalizing the scores

This commit is contained in:
mrq 2024-12-17 14:22:30 -06:00
parent c2c6d912ac
commit 7129582303
2 changed files with 11 additions and 9 deletions

View File

@ -279,7 +279,7 @@ def main():
# pull from provided samples # pull from provided samples
samples_dirs = { samples_dirs = {
#"librispeech": args.demo_dir / "librispeech", "librispeech": args.demo_dir / "librispeech",
} }
if (args.demo_dir / args.dataset_dir_name).exists(): if (args.demo_dir / args.dataset_dir_name).exists():
@ -407,8 +407,9 @@ def main():
if calculate: if calculate:
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model ) wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
#sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score}
json_write( metrics, metrics_path ) json_write( metrics, metrics_path )
else: else:
metrics = json_read( metrics_path ) metrics = json_read( metrics_path )

View File

@ -10,12 +10,7 @@ import torch.nn.functional as F
from pathlib import Path from pathlib import Path
from torcheval.metrics.functional import word_error_rate from torcheval.metrics.functional import word_error_rate
from torchmetrics.functional.text import char_error_rate
# cringe warning message
try:
from torchmetrics.text import CharErrorRate
except Exception as e:
from torchmetrics import CharErrorRate
def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ): def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ):
if language == "auto": if language == "auto":
@ -43,7 +38,13 @@ def wer( audio, reference, language="auto", normalize=True, phonemize=True, **tr
reference = encode( reference, language=language ) reference = encode( reference, language=language )
wer_score = word_error_rate([transcription], [reference]).item() wer_score = word_error_rate([transcription], [reference]).item()
cer_score = CharErrorRate()([transcription], [reference]).item() # un-normalize
wer_score *= len(reference.split())
cer_score = char_error_rate([transcription], [reference]).item()
# un-normalize
cer_score *= len(reference)
return wer_score, cer_score return wer_score, cer_score
def sim_o( audio, reference, **kwargs ): def sim_o( audio, reference, **kwargs ):