diff --git a/vall_e/demo.py b/vall_e/demo.py
index a050102..13d4160 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -35,6 +35,9 @@ from .utils import setup_logging
from tqdm import tqdm, trange
+def mean( l ): # arithmetic mean of a sized sequence; NaN (not ZeroDivisionError) when it is empty, e.g. no metrics were collected
+ return sum(l) / len(l) if len(l) else float("nan")
+
def encode(path):
if path is None or not path.exists():
return ""
@@ -132,6 +135,9 @@ def main():
parser.add_argument("--lora", action="store_true")
parser.add_argument("--comparison", type=str, default=None)
+ parser.add_argument("--transcription-model", type=str, default="base")
+ parser.add_argument("--speaker-similarity-model", type=str, default="wavlm_base_plus")
+
args = parser.parse_args()
config = None
@@ -149,6 +155,10 @@ def main():
args.preamble = "
".join([
'Below are some samples from my VALL-E implementation: https://git.ecker.tech/mrq/vall-e/.',
'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
+ f'Objective metrics are computed by transcribing ({args.transcription_model}) then comparing the word error rate on transcriptions (WER/CER), and computing the cosine similarities on embeddings through a speaker feature extraction model ({args.speaker_similarity_model}) (SIM-O)',
+ 'Total WER: ${WER}', # separate list items: without the commas Python fuses adjacent literals into one entry
+ 'Total CER: ${CER}',
+ 'Total SIM-O: ${SIM-O}',
])
# comparison kwargs
@@ -384,17 +394,18 @@ def main():
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
metrics_map = {}
- for text, language, out_path, reference_path in metrics_inputs:
- wer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name="base" )
- sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype )
- metrics_map[out_path] = (wer_score, sim_o_score)
+ total_metrics = (0, 0)
+ for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"):
+ wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
+ sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, feat_type=args.speaker_similarity_model )
+ metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
# collate entries into HTML
for k, samples in outputs:
samples = [
f'\n\t\t\t
\n\t\t\t\t{text} | '+
"".join([
- f'\n\t\t\t\t{metrics_map[audios[1]][0]:.3f} | {metrics_map[audios[1]][1]:.3f} | '
+ f'\n\t\t\t\t{metrics_map[audios[1]][0]:.3f} | {metrics_map[audios[1]][1]:.3f} | {metrics_map[audios[1]][2]:.3f} | '
] ) +
"".join( [
f'\n\t\t\t\t | '
@@ -406,6 +417,10 @@ def main():
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
+
+ html = html.replace("${WER}", f'{mean([ metrics[0] for metrics in metrics_map.values() ]):.3f}' )
+ html = html.replace("${CER}", f'{mean([ metrics[1] for metrics in metrics_map.values() ]):.3f}' )
+ html = html.replace("${SIM-O}", f'{mean([ metrics[2] for metrics in metrics_map.values() ]):.3f}' )
if args.comparison:
disabled, enabled = comparison_kwargs["titles"]
diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index 42b16db..8648770 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -46,8 +46,24 @@ tts = None
# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
@cache
-def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim=768):
+def _load_sim_model(device="cuda", dtype="float16", feat_type="wavlm_base_plus", feat_dim="auto"):
+ logging.getLogger('s3prl').setLevel(logging.DEBUG)
+ logging.getLogger('speechbrain').setLevel(logging.DEBUG)
+
from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
+
+ if feat_dim == "auto":
+ if feat_type == "fbank":
+ feat_dim = 40
+ elif feat_type == "wavlm_base_plus":
+ feat_dim = 768
+ elif feat_type == "wavlm_large":
+ feat_dim = 1024
+ elif feat_type == "hubert_large_ll60k":
+ feat_dim = 1024
+ elif feat_type == "wav2vec2_xlsr":
+ feat_dim = 1024
+
model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=None)
model = model.to(device=device, dtype=coerce_dtype(dtype))
model = model.eval()
diff --git a/vall_e/metrics.py b/vall_e/metrics.py
index 67388b8..032e38b 100644
--- a/vall_e/metrics.py
+++ b/vall_e/metrics.py
@@ -3,28 +3,43 @@
#from .emb.transcribe import transcribe
from .emb.similar import speaker_similarity_embedding
from .emb.transcribe import transcribe
-from .emb.g2p import detect_language
+from .emb.g2p import detect_language, coerce_to_hiragana, encode
from .data import normalize_text
import torch.nn.functional as F
from pathlib import Path
from torcheval.metrics.functional import word_error_rate
+from torchmetrics import CharErrorRate
-def wer( audio, reference, language="auto", **transcription_kwargs ):
+def wer( audio, reference, language="auto", normalize=True, phonemize=True, **transcription_kwargs ): # scores the transcription of `audio` against `reference`; returns a (wer_score, cer_score) tuple
if language == "auto":
language = detect_language( reference )
- transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )["text"]
+ transcription = transcribe( audio, language=language, align=False, **transcription_kwargs ) # keep the whole result: both "language" and "text" are read below
+ if language == "auto": # only still "auto" here if detect_language() returned "auto"; then use the language reported by the transcription
+ language = transcription["language"]
+ transcription = transcription["text"]
# reference audio needs transcribing too
if isinstance( reference, Path ):
reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]
- transcription = normalize_text( transcription )
- reference = normalize_text( reference )
+ if language == "ja": # Japanese: pass both sides through coerce_to_hiragana() before normalization and scoring
+ transcription = coerce_to_hiragana( transcription )
+ reference = coerce_to_hiragana( reference )
- return word_error_rate([transcription], [reference]).item()
+ if normalize: # optional: apply normalize_text() to both sides
+ transcription = normalize_text( transcription )
+ reference = normalize_text( reference )
+
+ if phonemize: # optional: score the output of encode() instead of the text itself (presumably phonemes; confirm against emb.g2p.encode)
+ transcription = encode( transcription, language=language )
+ reference = encode( reference, language=language )
+
+ wer_score = word_error_rate([transcription], [reference]).item()
+ cer_score = CharErrorRate()([transcription], [reference]).item() # a fresh CharErrorRate metric is instantiated on every call
+ return wer_score, cer_score
def sim_o( audio, reference, **kwargs ):
audio_emb = speaker_similarity_embedding( audio, **kwargs )
diff --git a/vall_e/utils/ext/ecapa_tdnn.py b/vall_e/utils/ext/ecapa_tdnn.py
index 29a99fb..8f28f69 100644
--- a/vall_e/utils/ext/ecapa_tdnn.py
+++ b/vall_e/utils/ext/ecapa_tdnn.py
@@ -1,4 +1,5 @@
# borrowed with love from "https://github.com/keonlee9420/evaluate-zero-shot-tts/blob/master/src/evaluate_zero_shot_tts/utils/speaker_verification/models/ecapa_tdnn.py"
+# (which was from https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py)
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import torch