2024-12-11 02:13:21 +00:00
|
|
|
# handles objective metric calculations, such as WER and SIM-O
|
|
|
|
|
|
|
|
#from .emb.transcribe import transcribe
|
|
|
|
from .emb.similar import speaker_similarity_embedding
|
|
|
|
from .emb.transcribe import transcribe
|
2024-12-11 03:00:51 +00:00
|
|
|
from .emb.g2p import detect_language, coerce_to_hiragana, encode
|
2024-12-11 02:13:21 +00:00
|
|
|
from .data import normalize_text
|
|
|
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
from torcheval.metrics.functional import word_error_rate
|
2024-12-17 20:22:30 +00:00
|
|
|
from torchmetrics.functional.text import char_error_rate
|
2024-12-11 02:13:21 +00:00
|
|
|
|
2024-12-19 01:58:53 +00:00
|
|
|
import warnings
|
|
|
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
|
|
|
warnings.simplefilter(action='ignore', category=UserWarning)
|
|
|
|
|
2024-12-19 05:43:11 +00:00
|
|
|
def wer( audio, reference, language="auto", phonemize=True, normalize=True, **transcription_kwargs ):
    """Score a transcription of `audio` against `reference`.

    Args:
        audio: input handed to `transcribe`.
        reference: the reference text, or a `Path` to reference audio, in
            which case it is transcribed before being compared.
        language: language code forwarded to the transcription / text helpers.
            With "auto", the language is detected from the reference text when
            the reference is text, otherwise taken from the transcription
            result of `audio`.
        phonemize: when true, both texts are passed through `encode` before
            scoring.
        normalize: when true and `phonemize` is false, both texts are passed
            through `normalize_text` before scoring.
        **transcription_kwargs: forwarded unchanged to every `transcribe` call.

    Returns:
        A `(wer_score, cer_score)` tuple. Each rate is multiplied back by the
        reference length (in words and in characters respectively), i.e. the
        scores are un-normalized.
    """
    reference_is_audio = isinstance( reference, Path )

    # Language detection works on text, so it must not be attempted on a Path
    # to reference audio; in that case keep "auto" and let the transcription
    # of `audio` report the language below.
    # NOTE(review): assumes `detect_language` expects a string -- confirm.
    if language == "auto" and not reference_is_audio:
        language = detect_language( reference )

    transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )
    if language == "auto":
        language = transcription["language"]
    transcription = transcription["text"]

    # reference audio needs transcribing too
    if reference_is_audio:
        reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]

    if language == "ja":
        transcription = coerce_to_hiragana( transcription )
        reference = coerce_to_hiragana( reference )

    # phonemizing takes precedence over plain text normalization
    if phonemize:
        transcription = encode( transcription, language=language )
        reference = encode( reference, language=language )
    elif normalize:
        transcription = normalize_text( transcription, language=language )
        reference = normalize_text( reference, language=language )

    wer_score = word_error_rate([transcription], [reference]).item()
    # un-normalize: scale the rate by the number of reference words
    wer_score *= len(reference.split())

    cer_score = char_error_rate([transcription], [reference]).item()
    # un-normalize: scale the rate by the number of reference characters
    cer_score *= len(reference)

    return wer_score, cer_score
|
2024-12-11 02:13:21 +00:00
|
|
|
|
|
|
|
def sim_o( audio, reference, **kwargs ):
    """Return the cosine similarity between the speaker embeddings of two inputs.

    Args:
        audio: first input handed to `speaker_similarity_embedding`.
        reference: second input handed to `speaker_similarity_embedding`.
        **kwargs: forwarded unchanged to both embedding calls.

    Returns:
        The cosine similarity, taken over the last dimension of the two
        embeddings, as a Python number (`.item()` requires the result to hold
        a single element).
    """
    # embed `audio` first, then `reference`, with identical settings
    embeddings = [ speaker_similarity_embedding( source, **kwargs ) for source in ( audio, reference ) ]
    similarity = F.cosine_similarity( *embeddings, dim=-1 )
    return similarity.item()
|