Added CER, transcription/similarity model args in demo
This commit is contained in:
parent
8568a93dad
commit
6f1ee0c6fa
|
@ -35,6 +35,9 @@ from .utils import setup_logging
|
|||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
def mean( l ):
|
||||
return sum(l) / len(l)
|
||||
|
||||
def encode(path):
|
||||
if path is None or not path.exists():
|
||||
return ""
|
||||
|
@ -132,6 +135,9 @@ def main():
|
|||
parser.add_argument("--lora", action="store_true")
|
||||
parser.add_argument("--comparison", type=str, default=None)
|
||||
|
||||
parser.add_argument("--transcription-model", type=str, default="base")
|
||||
parser.add_argument("--speaker-similarity-model", type=str, default="wavlm_base_plus")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = None
|
||||
|
@ -149,6 +155,10 @@ def main():
|
|||
args.preamble = "<br>".join([
|
||||
'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
|
||||
'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
|
||||
f'Objective metrics are computed by transcribing ({args.transcription_model}) then comparing the word error rate on transcriptions (WER/CER), and computing the cosine similarities on embeddings through a speaker feature extraction model ({args.speaker_similarity_model}) (SIM-O)',
|
||||
'<b>Total WER:</b> ${WER}'
|
||||
'<b>Total CER:</b> ${CER}'
|
||||
'<b>Total SIM-O:</b> ${SIM-O}'
|
||||
])
|
||||
|
||||
# comparison kwargs
|
||||
|
@ -384,17 +394,18 @@ def main():
|
|||
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
||||
|
||||
metrics_map = {}
|
||||
for text, language, out_path, reference_path in metrics_inputs:
|
||||
wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" )
|
||||
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype )
|
||||
metrics_map[out_path] = (wer_score, sim_o_score)
|
||||
total_metrics = (0, 0)
|
||||
for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
||||
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, feat_type=args.speaker_similarity_model )
|
||||
metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
|
||||
|
||||
# collate entries into HTML
|
||||
for k, samples in outputs:
|
||||
samples = [
|
||||
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
||||
"".join([
|
||||
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td>'
|
||||
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td><td>{metrics_map[audios[1]][2]:.3f}</td>'
|
||||
] ) +
|
||||
"".join( [
|
||||
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
|
||||
|
@ -406,6 +417,10 @@ def main():
|
|||
|
||||
# write audio into template
|
||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||
|
||||
html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' )
|
||||
html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' )
|
||||
html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' )
|
||||
|
||||
if args.comparison:
|
||||
disabled, enabled = comparison_kwargs["titles"]
|
||||
|
|
|
@ -46,8 +46,24 @@ tts = None
|
|||
|
||||
# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
|
||||
@cache
|
||||
def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768):
|
||||
def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim="auto"):
|
||||
logging.getLogger('s3prl').setLevel(logging.DEBUG)
|
||||
logging.getLogger('speechbrain').setLevel(logging.DEBUG)
|
||||
|
||||
from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
|
||||
|
||||
if feat_dim == "auto":
|
||||
if feat_type == "fbank":
|
||||
feat_dim = 40
|
||||
elif feat_type == "wavlm_base_plus":
|
||||
feat_dim = 768
|
||||
elif feat_type == "wavlm_large":
|
||||
feat_dim = 1024
|
||||
elif feat_type == "hubert_large_ll60k":
|
||||
feat_dim = 1024
|
||||
elif feat_type == "wav2vec2_xlsr":
|
||||
feat_dim = 1024
|
||||
|
||||
model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
|
||||
model = model.to(device=device, dtype=coerce_dtype(dtype))
|
||||
model = model.eval()
|
||||
|
|
|
@ -3,28 +3,43 @@
|
|||
#from .emb.transcribe import transcribe
|
||||
from .emb.similar import speaker_similarity_embedding
|
||||
from .emb.transcribe import transcribe
|
||||
from .emb.g2p import detect_language
|
||||
from .emb.g2p import detect_language, coerce_to_hiragana, encode
|
||||
from .data import normalize_text
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from pathlib import Path
|
||||
from torcheval.metrics.functional import word_error_rate
|
||||
from torchmetrics import CharErrorRate
|
||||
|
||||
def wer( audio, reference, language="auto", **transcription_kwargs ):
|
||||
def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ):
|
||||
if language == "auto":
|
||||
language = detect_language( reference )
|
||||
|
||||
transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )["text"]
|
||||
transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )
|
||||
if language == "auto":
|
||||
language = transcription["language"]
|
||||
transcription = transcription["text"]
|
||||
|
||||
# reference audio needs transcribing too
|
||||
if isinstance( reference, Path ):
|
||||
reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]
|
||||
|
||||
transcription = normalize_text( transcription )
|
||||
reference = normalize_text( reference )
|
||||
if language == "ja":
|
||||
transcription = coerce_to_hiragana( transcription )
|
||||
reference = coerce_to_hiragana( reference )
|
||||
|
||||
return word_error_rate([transcription], [reference]).item()
|
||||
if normalize:
|
||||
transcription = normalize_text( transcription )
|
||||
reference = normalize_text( reference )
|
||||
|
||||
if phonemize:
|
||||
transcription = encode( transcription, language=language )
|
||||
reference = encode( reference, language=language )
|
||||
|
||||
wer_score = word_error_rate([transcription], [reference]).item()
|
||||
cer_score = CharErrorRate()([transcription], [reference]).item()
|
||||
return wer_score, cer_score
|
||||
|
||||
def sim_o( audio, reference, **kwargs ):
|
||||
audio_emb = speaker_similarity_embedding( audio, **kwargs )
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
|
||||
# (which was from https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py)
|
||||
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||
|
||||
import torch
|
||||
|
|
Loading…
Reference in New Issue
Block a user