actually do proper wer/cer calculation by un-normalizing the scores
This commit is contained in:
parent
c2c6d912ac
commit
7129582303
|
@ -279,7 +279,7 @@ def main():
|
|||
|
||||
# pull from provided samples
|
||||
samples_dirs = {
|
||||
#"librispeech": args.demo_dir / "librispeech",
|
||||
"librispeech": args.demo_dir / "librispeech",
|
||||
}
|
||||
|
||||
if (args.demo_dir / args.dataset_dir_name).exists():
|
||||
|
@ -407,8 +407,9 @@ def main():
|
|||
if calculate:
|
||||
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
||||
sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||
#sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||
|
||||
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
|
||||
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score}
|
||||
json_write( metrics, metrics_path )
|
||||
else:
|
||||
metrics = json_read( metrics_path )
|
||||
|
|
|
@ -10,12 +10,7 @@ import torch.nn.functional as F
|
|||
|
||||
from pathlib import Path
|
||||
from torcheval.metrics.functional import word_error_rate
|
||||
|
||||
# cringe warning message
|
||||
try:
|
||||
from torchmetrics.text import CharErrorRate
|
||||
except Exception as e:
|
||||
from torchmetrics import CharErrorRate
|
||||
from torchmetrics.functional.text import char_error_rate
|
||||
|
||||
def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ):
|
||||
if language == "auto":
|
||||
|
@ -43,7 +38,13 @@ def wer( audio, reference, language="auto", normalize=True, phonemize=True, **tr
|
|||
reference = encode( reference, language=language )
|
||||
|
||||
wer_score = word_error_rate([transcription], [reference]).item()
|
||||
cer_score = CharErrorRate()([transcription], [reference]).item()
|
||||
# un-normalize
|
||||
wer_score *= len(reference.split())
|
||||
|
||||
cer_score = char_error_rate([transcription], [reference]).item()
|
||||
# un-normalize
|
||||
cer_score *= len(reference)
|
||||
|
||||
return wer_score, cer_score
|
||||
|
||||
def sim_o( audio, reference, **kwargs ):
|
||||
|
|
Loading…
Reference in New Issue
Block a user