actually do proper wer/cer calculation by un-normalizing the scores
This commit is contained in:
parent c2c6d912ac
commit 7129582303
@@ -279,7 +279,7 @@ def main():
 
 	# pull from provided samples
 	samples_dirs = {
-		#"librispeech": args.demo_dir / "librispeech",
+		"librispeech": args.demo_dir / "librispeech",
 	}
 
 	if (args.demo_dir / args.dataset_dir_name).exists():
@@ -407,8 +407,9 @@ def main():
 			if calculate:
 				wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
 				sim_o_score = sim_o( out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
+				#sim_o_r_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
 
-				metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
+				metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} # , "sim-o-r": sim_o_r_score}
 				json_write( metrics, metrics_path )
 			else:
 				metrics = json_read( metrics_path )
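For reference: after this commit, each sample's metrics.json holds un-normalized scores, i.e. approximate error counts rather than rates. Illustrative shape only, values made up:

	{"wer": 2.0, "cer": 5.0, "sim-o": 0.91}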
@@ -10,12 +10,7 @@ import torch.nn.functional as F
 
 from pathlib import Path
 from torcheval.metrics.functional import word_error_rate
-# cringe warning message
-try:
-	from torchmetrics.text import CharErrorRate
-except Exception as e:
-	from torchmetrics import CharErrorRate
+from torchmetrics.functional.text import char_error_rate
 
 def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ):
 	if language == "auto":
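A quick sanity check of the two imports in isolation (a minimal sketch, not from the repo; both functions take a list of predictions and a list of references and return a normalized rate as a tensor):

	from torcheval.metrics.functional import word_error_rate
	from torchmetrics.functional.text import char_error_rate

	# one substituted word out of three reference words -> ~0.333
	print(word_error_rate(["a b c"], ["a b x"]).item())
	# one substituted character out of three reference characters -> ~0.333
	print(char_error_rate(["abc"], ["abx"]).item())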
@@ -43,7 +38,13 @@ def wer( audio, reference, language="auto", normalize=True, phonemize=True, **tr
 	reference = encode( reference, language=language )
 
 	wer_score = word_error_rate([transcription], [reference]).item()
-	cer_score = CharErrorRate()([transcription], [reference]).item()
+	# un-normalize
+	wer_score *= len(reference.split())
+
+	cer_score = char_error_rate([transcription], [reference]).item()
+	# un-normalize
+	cer_score *= len(reference)
 
 	return wer_score, cer_score
 
 def sim_o( audio, reference, **kwargs ):
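Why un-normalize: both metrics return a rate (edit distance divided by reference length), so averaging those rates across samples would weight a 2-word utterance the same as a 40-word one. Multiplying back by the reference length recovers an error count, which can be pooled into a proper corpus-level rate. A minimal sketch of that aggregation, with hypothetical (metrics, reference) pairs standing in for however the per-sample metrics.json files are collected:

	total_word_errors, total_words = 0.0, 0
	total_char_errors, total_chars = 0.0, 0
	for metrics, reference in samples:  # hypothetical iterable of (dict, str)
		total_word_errors += metrics["wer"]   # un-normalized: rate * word count
		total_words += len(reference.split())
		total_char_errors += metrics["cer"]   # un-normalized: rate * char count
		total_chars += len(reference)

	corpus_wer = total_word_errors / total_words
	corpus_cer = total_char_errors / total_chars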