Added CER, transcription/similarity model args in demo
This commit is contained in:
parent
8568a93dad
commit
6f1ee0c6fa
|
@ -35,6 +35,9 @@ from .utils import setup_logging
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
def mean( l ):
    """Return the arithmetic mean of the numbers in `l`.

    `l` may be any iterable of numbers. An empty input yields NaN instead of
    raising ZeroDivisionError, so an aggregate can still be formatted when
    no values were collected.
    """
    # materialize once so generators and other one-shot iterables work too
    values = list( l )
    if not values:
        return float( "nan" )
    return sum( values ) / len( values )
|
||||||
|
|
||||||
def encode(path):
|
def encode(path):
|
||||||
if path is None or not path.exists():
|
if path is None or not path.exists():
|
||||||
return ""
|
return ""
|
||||||
|
@ -132,6 +135,9 @@ def main():
|
||||||
parser.add_argument("--lora", action="store_true")
|
parser.add_argument("--lora", action="store_true")
|
||||||
parser.add_argument("--comparison", type=str, default=None)
|
parser.add_argument("--comparison", type=str, default=None)
|
||||||
|
|
||||||
|
parser.add_argument("--transcription-model", type=str, default="base")
|
||||||
|
parser.add_argument("--speaker-similarity-model", type=str, default="wavlm_base_plus")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config = None
|
config = None
|
||||||
|
@ -149,6 +155,10 @@ def main():
|
||||||
# Default preamble: one list entry per rendered line, joined with <br>.
# The ${WER}, ${CER} and ${SIM-O} placeholders are filled in later via html.replace().
args.preamble = "<br>".join([
    'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
    'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
    f'Objective metrics are computed by transcribing ({args.transcription_model}) then comparing the word error rate on transcriptions (WER/CER), and computing the cosine similarities on embeddings through a speaker feature extraction model ({args.speaker_similarity_model}) (SIM-O)',
    # every entry needs its own comma: adjacent string literals are implicitly
    # concatenated into a single entry, which would drop the <br> between totals
    '<b>Total WER:</b> ${WER}',
    '<b>Total CER:</b> ${CER}',
    '<b>Total SIM-O:</b> ${SIM-O}',
])
|
||||||
|
|
||||||
# comparison kwargs
|
# comparison kwargs
|
||||||
|
@ -384,17 +394,18 @@ def main():
|
||||||
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
||||||
|
|
||||||
metrics_map = {}
|
metrics_map = {}
|
||||||
for text, language, out_path, reference_path in metrics_inputs:
|
total_metrics = (0, 0)
|
||||||
wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" )
|
for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||||
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype )
|
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
||||||
metrics_map[out_path] = (wer_score, sim_o_score)
|
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, feat_type=args.speaker_similarity_model )
|
||||||
|
metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
|
||||||
|
|
||||||
# collate entries into HTML
|
# collate entries into HTML
|
||||||
for k, samples in outputs:
|
for k, samples in outputs:
|
||||||
samples = [
|
samples = [
|
||||||
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
||||||
"".join([
|
"".join([
|
||||||
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td>'
|
f'\n\t\t\t\t<td>{metrics_map[audios[1]][0]:.3f}</td><td>{metrics_map[audios[1]][1]:.3f}</td><td>{metrics_map[audios[1]][2]:.3f}</td>'
|
||||||
] ) +
|
] ) +
|
||||||
"".join( [
|
"".join( [
|
||||||
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
|
f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{str(audio).replace(str(args.demo_dir), args.audio_path_root) if args.audio_path_root else encode(audio)}"/></audio></td>'
|
||||||
|
@ -406,6 +417,10 @@ def main():
|
||||||
|
|
||||||
# write audio into template
|
# write audio into template
|
||||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||||
|
|
||||||
|
html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' )
|
||||||
|
html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' )
|
||||||
|
html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' )
|
||||||
|
|
||||||
if args.comparison:
|
if args.comparison:
|
||||||
disabled, enabled = comparison_kwargs["titles"]
|
disabled, enabled = comparison_kwargs["titles"]
|
||||||
|
|
|
@ -46,8 +46,24 @@ tts = None
|
||||||
|
|
||||||
# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
|
# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
|
||||||
@cache
|
@cache
|
||||||
def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768):
|
def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim="auto"):
|
||||||
|
logging.getLogger('s3prl').setLevel(logging.DEBUG)
|
||||||
|
logging.getLogger('speechbrain').setLevel(logging.DEBUG)
|
||||||
|
|
||||||
from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
|
from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
|
||||||
|
|
||||||
|
if feat_dim == "auto":
|
||||||
|
if feat_type == "fbank":
|
||||||
|
feat_dim = 40
|
||||||
|
elif feat_type == "wavlm_base_plus":
|
||||||
|
feat_dim = 768
|
||||||
|
elif feat_type == "wavlm_large":
|
||||||
|
feat_dim = 1024
|
||||||
|
elif feat_type == "hubert_large_ll60k":
|
||||||
|
feat_dim = 1024
|
||||||
|
elif feat_type == "wav2vec2_xlsr":
|
||||||
|
feat_dim = 1024
|
||||||
|
|
||||||
model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
|
model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
|
||||||
model = model.to(device=device, dtype=coerce_dtype(dtype))
|
model = model.to(device=device, dtype=coerce_dtype(dtype))
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
|
@ -3,28 +3,43 @@
|
||||||
#from .emb.transcribe import transcribe
|
#from .emb.transcribe import transcribe
|
||||||
from .emb.similar import speaker_similarity_embedding
|
from .emb.similar import speaker_similarity_embedding
|
||||||
from .emb.transcribe import transcribe
|
from .emb.transcribe import transcribe
|
||||||
from .emb.g2p import detect_language
|
from .emb.g2p import detect_language, coerce_to_hiragana, encode
|
||||||
from .data import normalize_text
|
from .data import normalize_text
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torcheval.metrics.functional import word_error_rate
|
from torcheval.metrics.functional import word_error_rate
|
||||||
|
from torchmetrics import CharErrorRate
|
||||||
|
|
||||||
def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ):
    """Compute word and character error rates of `audio` against `reference`.

    Args:
        audio: the utterance to score; it is passed to `transcribe`.
        reference: the expected text, or a `Path` to reference audio, in
            which case it is transcribed as well.
        language: language code, or "auto" to determine it from the
            reference text or, failing that, from the transcription result.
        normalize: when true, run both strings through `normalize_text`.
        phonemize: when true, run both strings through `encode` with the
            resolved language before scoring.
        **transcription_kwargs: forwarded unchanged to `transcribe`.

    Returns:
        A `(wer_score, cer_score)` tuple of Python floats.
    """
    reference_is_audio = isinstance( reference, Path )

    # a Path reference is audio rather than text, so there is nothing to run
    # text language detection on; leave it to the transcriber below
    if language == "auto" and not reference_is_audio:
        language = detect_language( reference )

    transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )
    # fall back to whatever language the transcriber reported
    if language == "auto":
        language = transcription["language"]
    transcription = transcription["text"]

    # reference audio needs transcribing too
    if reference_is_audio:
        reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]

    # Japanese text is coerced to hiragana on both sides before comparing
    if language == "ja":
        transcription = coerce_to_hiragana( transcription )
        reference = coerce_to_hiragana( reference )

    if normalize:
        transcription = normalize_text( transcription )
        reference = normalize_text( reference )

    if phonemize:
        transcription = encode( transcription, language=language )
        reference = encode( reference, language=language )

    wer_score = word_error_rate([transcription], [reference]).item()
    cer_score = CharErrorRate()([transcription], [reference]).item()
    return wer_score, cer_score
|
||||||
|
|
||||||
def sim_o( audio, reference, **kwargs ):
|
def sim_o( audio, reference, **kwargs ):
|
||||||
audio_emb = speaker_similarity_embedding( audio, **kwargs )
|
audio_emb = speaker_similarity_embedding( audio, **kwargs )
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
|
# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
|
||||||
|
# (which was from https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py)
|
||||||
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
Loading…
Reference in New Issue
Block a user