2024-07-20 01:49:40 +00:00
"""
A helper script to generate a demo page.

Layout as expected:
    ./data/demo/:
        {speaker ID}:
            out:
                ours.wav (generated)
                ms_valle.wav
                yourtts.wav
            prompt.txt (text to generate)
            prompt.wav (reference clip to serve as the prompt)
            reference.wav (ground truth utterance)

Will also generate samples from a provided dataset, if requested.
"""
import argparse
import base64
import random
2024-08-29 18:27:16 +00:00
import logging
2024-10-11 00:04:12 +00:00
import time
2024-12-11 02:13:21 +00:00
import torch
2024-08-29 18:27:16 +00:00
_logger = logging . getLogger ( __name__ )
2024-07-20 01:49:40 +00:00
from pathlib import Path
from . inference import TTS
from . config import cfg
2024-10-10 18:40:25 +00:00
from . data import create_train_dataloader , create_val_dataloader , get_random_prompt
2024-07-20 01:49:40 +00:00
from . emb . qnt import decode_to_file
2024-12-11 02:13:21 +00:00
from . metrics import wer , sim_o
from . utils import setup_logging
2024-12-12 02:55:43 +00:00
from . utils . io import json_read , json_write
2024-07-20 01:49:40 +00:00
2024-07-22 00:12:03 +00:00
from tqdm import tqdm , trange
2024-07-20 01:49:40 +00:00
2024-12-11 03:00:51 +00:00
def mean(l):
    """Return the arithmetic mean of the sized collection *l*.

    Raises ZeroDivisionError when *l* is empty.
    """
    total = sum(l)
    count = len(l)
    return total / count
2024-07-20 01:49:40 +00:00
def encode(path):
    """Return the WAV file at *path* as a base64 ``data:`` URI.

    Args:
        path: a ``pathlib.Path`` to the audio file, or ``None``.

    Returns:
        ``"data:audio/wav;base64,<payload>"``, or an empty string when *path*
        is ``None`` or does not exist.
    """
    if path is None or not path.exists():
        return ""

    # read_bytes() opens and closes the file itself, so no handle is leaked.
    payload = base64.b64encode(path.read_bytes()).decode("utf-8")
    return "data:audio/wav;base64," + payload
2024-12-08 01:21:05 +00:00
def safe_inference(tts, out_path, skip_existing=False, **kwargs):
    """Run ``tts.inference`` for a single output file.

    Args:
        tts: object exposing ``inference(out_path=..., **kwargs)``.
        out_path: ``pathlib.Path`` the audio is written to.
        skip_existing: when true, do nothing if *out_path* already exists.
            (This used to be read from an ``args`` name that only exists
            inside ``main``, which raised ``NameError`` on every call.)
        **kwargs: forwarded unchanged to ``tts.inference``.

    Raises:
        Whatever ``tts.inference`` raises; the failure is logged first and
        then re-raised unchanged.
    """
    if skip_existing and out_path.exists():
        return

    try:
        tts.inference(out_path=out_path, **kwargs)
    except Exception as e:
        # Report which output failed before propagating; previously this
        # message sat after the raise and could never run.
        logging.getLogger(__name__).error(f'Error while processing {out_path}: {e}')
        raise
def safe_batched_inference(tts, **kwargs):
    """Run ``tts.batched_inference(**kwargs)``, logging any failure.

    Args:
        tts: object exposing ``batched_inference(**kwargs)``.
        **kwargs: forwarded unchanged to ``tts.batched_inference``.

    Raises:
        Whatever ``tts.batched_inference`` raises; the failure is logged first
        and then re-raised unchanged.
    """
    try:
        tts.batched_inference(**kwargs)
    except Exception as e:
        # Previously this message sat after the raise and could never run.
        logging.getLogger(__name__).error(f'Error while processing batch: {e}')
        raise
def process_batch(tts, inputs, kwargs=None):
    """Run one batched inference over a list of queued inputs.

    Args:
        tts: passed through to ``safe_batched_inference``.
        inputs: sequence of ``(text, reference, language, out_path)`` tuples;
            each position is gathered into its own list.
        kwargs: optional mapping of extra keyword arguments. It is not
            modified; the per-batch lists override any same-named keys.
    """
    # None instead of a shared `{}` default, so no dict is shared across calls.
    batch_kwargs = (kwargs or {}) | dict(
        texts=[entry[0] for entry in inputs],
        references=[entry[1] for entry in inputs],
        languages=[entry[2] for entry in inputs],
        out_paths=[entry[3] for entry in inputs],
    )

    safe_batched_inference(tts, **batch_kwargs)
2024-07-20 01:49:40 +00:00
# Would be downright sugoi if I could incorporate this into __main__
def main():
    """Generate the HTML demo page from the command line.

    Steps, in order:
      1. Parse the CLI arguments and construct the ``TTS`` instance.
      2. Build the preamble text and the optional A/B comparison settings.
      3. Read ``index.template.html`` from the demo directory.
      4. Optionally write sample folders (prompt.txt, language.txt,
         prompt.wav, reference.wav) drawn from the training dataloader.
      5. Queue every sample folder for batched inference, plus a second batch
         for the comparison variant, and run both batches.
      6. Compute WER/CER/PER/SIM-O per output, cached in ``metrics.json``.
      7. Render one table per dataset into the template and write the page to
         ``<demo-dir>/<output-filename>``.

    Raises:
        Exception: when ``--comparison`` names an unrecognized mode.
    """
    # Imported locally under an alias because the page text below is held in a
    # local variable named `html`.
    from html import escape as html_escape

    parser = argparse.ArgumentParser("VALL-E TTS Demo")

    parser.add_argument("--yaml", type=Path, default=None)
    parser.add_argument("--model", type=Path, default=None)
    parser.add_argument("--batch-size", type=int, default=cfg.inference.batch_size)

    parser.add_argument("--demo-dir", type=Path, default=None)
    parser.add_argument("--skip-existing", action="store_true")
    parser.add_argument("--dataset-dir-name", type=str, default="")
    parser.add_argument("--dataset-dir-name-prefix", type=str, default=None)
    parser.add_argument("--sample-from-dataset", action="store_true")
    parser.add_argument("--skip-loading-dataloader", action="store_true")
    parser.add_argument("--dataset-samples", type=int, default=0)
    parser.add_argument("--audio-path-root", type=str, default=None)
    parser.add_argument("--preamble", type=str, default=None)
    parser.add_argument("--output-filename", type=str, default="index.html")

    parser.add_argument("--language", type=str, default="auto")
    parser.add_argument("--task", type=str, default="tts")
    parser.add_argument("--modality", type=str, default="auto")
    parser.add_argument("--out-path", type=Path, default=None)

    parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
    parser.add_argument("--max-steps", type=int, default=30)
    parser.add_argument("--max-levels", type=int, default=7)

    parser.add_argument("--ar-temperature", type=float, default=1.0)
    parser.add_argument("--nar-temperature", type=float, default=0.0)
    parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
    parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
    parser.add_argument("--input-prompt-length", type=float, default=5.0)
    parser.add_argument("--input-prompt-prefix", action="store_true")
    parser.add_argument("--prefix-silence", type=float, default=0.0)
    parser.add_argument("--cfg-strength", type=float, default=1.0)
    parser.add_argument("--cfg-rescale", type=float, default=0.75)

    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--top-no", type=float, default=0.0)
    parser.add_argument("--min-p", type=float, default=0.0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
    parser.add_argument("--length-penalty", type=float, default=0.0)
    parser.add_argument("--beam-width", type=int, default=0)
    parser.add_argument("--mirostat-tau", type=float, default=0)
    parser.add_argument("--mirostat-eta", type=float, default=0)

    parser.add_argument("--dry-multiplier", type=float, default=0)
    parser.add_argument("--dry-base", type=float, default=1.75)
    parser.add_argument("--dry-allowed-length", type=int, default=2)

    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default=None)
    parser.add_argument("--attention", type=str, default="auto")
    parser.add_argument("--random-prompts", action="store_true")
    parser.add_argument("--lora", action="store_true")
    parser.add_argument("--comparison", type=str, default=None)

    parser.add_argument("--transcription-model", type=str, default="openai/whisper-large-v3")
    parser.add_argument("--speaker-similarity-model", type=str, default="microsoft/wavlm-large")

    args = parser.parse_args()

    # --yaml takes precedence over --model; None when neither is given.
    config = None
    if args.yaml:
        config = args.yaml
    elif args.model:
        config = args.model

    tts = TTS(config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention)

    if not args.demo_dir:
        args.demo_dir = Path("./data/demo/")

    if not args.preamble:
        # NOTE(review): the text below names openai/whisper-base while
        # --transcription-model defaults to openai/whisper-large-v3 — confirm
        # which one the page should advertise.
        args.preamble = "\n".join([
            'Past model demo pages: <a href="/models/ar+nar-len-llama-8 (ar+nar).html">[ar+nar-len-llama-8 (ar+nar)]</a> <a href="/models/ar+nar-len-llama-8 (nar-len).html">[ar+nar-len-llama-8 (nar-len)]</a> <a href="/models/ar+nar-llama-8 (ar+nar).html">[ar+nar-llama-8 (ar+nar)]</a> | Old demo pages: <a href="/loras/index.html">[1]</a> <a href="/loras.html">[2]</a> <a href="/old/2024.10.25.html">[3]</a> <a href="/old/2024.12.15.html">[4]</a> <a href="/old/2024.12.16.html">[5]</a>',
            "<br>",
            "<br>",
            'Below are some samples from my VALL-E implementation: <a href="https://git.ecker.tech/mrq/vall-e/">https://git.ecker.tech/mrq/vall-e/</a>.',
            "<br>",
            "Objective metrics are computed by:",
            "<ul>",
            "\t<li>WER/CER: transcribing (openai/whisper-base) then comparing the un-normalized word error rate on the phonemized transcriptions.</li>",
            "\t<li>SIM-O: retrieving the speaker embeddings of the output audio and the input prompt, from a finetune of WavLM for speaker verification (microsoft/wavlm-large), and computing the cosine similarity between the embeddings.</li>",
            "</ul>",
            'Tables marked as "Validation" are speakers/samples not seen to the model.',
            "<br>",
            f"These samples were generated using <pre>--ar-temperature={args.ar_temperature} --nar-temperature={args.nar_temperature} --cfg-strength={args.cfg_strength} --max-steps={args.max_steps} --top-k={args.top_k} --dtype={args.dtype}</pre>",
        ])

    # comparison kwargs: "disabled" overrides go to the regular batch and
    # "enabled" overrides go to the comparison batch (see the process_batch
    # calls further down).
    comparison_kwargs = {
        "titles": [],
        "suffix": "diff",
        "enabled": {},
        "disabled": {},
    }

    if args.lora:
        args.comparison = "lora"

    # to-do: just make this mappable
    if args.comparison == "lora":
        comparison_kwargs["suffix"] = "no_lora"
        comparison_kwargs["titles"] = ["LoRA", "No LoRA"]

        comparison_kwargs["disabled"]["use_lora"] = True
        comparison_kwargs["disabled"]["ar_temperature"] = 0.0
        comparison_kwargs["enabled"]["use_lora"] = False
        comparison_kwargs["enabled"]["ar_temperature"] = 0.95
    elif args.comparison == "ar-temp":
        current_temperature = args.ar_temperature
        other_temperature = 1.0

        comparison_kwargs["suffix"] = "temperature"
        comparison_kwargs["titles"] = [f"Temp: {current_temperature:.2f}", f"Temp: {other_temperature:.2f}"]

        comparison_kwargs["disabled"]["ar_temperature"] = current_temperature
        comparison_kwargs["enabled"]["ar_temperature"] = other_temperature
    elif args.comparison == "input-prompt-length":
        current_length = args.input_prompt_length
        other_length = 3.0

        comparison_kwargs["suffix"] = "input_prompt_length"
        comparison_kwargs["titles"] = [f"Prompt Length: {current_length:.2f}s", f"Prompt Length: {other_length:.2f}s"]

        comparison_kwargs["disabled"]["input_prompt_length"] = current_length
        comparison_kwargs["enabled"]["input_prompt_length"] = other_length
    elif args.comparison == "dtype":
        current_dtype = cfg.inference.weight_dtype
        other_dtype = "float32"

        # Pair the two half-precision formats against each other; anything
        # else is compared against float32.
        if current_dtype == "float16":
            other_dtype = "bfloat16"
        elif current_dtype == "bfloat16":
            other_dtype = "float16"

        comparison_kwargs["suffix"] = f"dtype_{other_dtype}"
        comparison_kwargs["titles"] = [f"With {current_dtype}", f"With {other_dtype}"]

        comparison_kwargs["disabled"]["dtype"] = current_dtype
        comparison_kwargs["enabled"]["dtype"] = other_dtype
    elif args.comparison == "amp":
        current_amp = cfg.inference.weight_amp
        other_amp = not current_amp

        comparison_kwargs["suffix"] = f"with{'out' if not other_amp else ''}_amp"
        comparison_kwargs["titles"] = [f"With {current_amp}", f"With {other_amp}"]

        comparison_kwargs["disabled"]["amp"] = current_amp
        comparison_kwargs["enabled"]["amp"] = other_amp
    elif args.comparison == "modality":
        comparison_kwargs["suffix"] = "modality"
        comparison_kwargs["titles"] = ["AR+NAR", "NAR-len"]

        comparison_kwargs["disabled"]["modality"] = "ar+nar"
        comparison_kwargs["disabled"]["cfg_strength"] = 0.0

        comparison_kwargs["enabled"]["modality"] = "nar-len"
        comparison_kwargs["enabled"]["cfg_strength"] = 3.0
    elif args.comparison == "cfg-strength":
        current_cfg_strength = 3.0
        other_cfg_strength = 0.0

        comparison_kwargs["suffix"] = "no_cfg_strength"
        comparison_kwargs["titles"] = [f"CFG {current_cfg_strength}", f"CFG {other_cfg_strength}"]

        comparison_kwargs["disabled"]["cfg_strength"] = current_cfg_strength
        comparison_kwargs["enabled"]["cfg_strength"] = other_cfg_strength
    elif args.comparison:
        raise Exception(f"Unrecognized comparison flag: {args.comparison}")

    setup_logging()

    # read html template
    html = (args.demo_dir / "index.template.html").read_text(encoding="utf-8")

    sampling_kwargs = dict(
        task=args.task,
        modality=args.modality,
        max_steps=args.max_steps,
        max_levels=args.max_levels,
        max_duration=args.max_duration,
        ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
        min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
        top_p=args.top_p, top_k=args.top_k, top_no=args.top_no, min_p=args.min_p,
        repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
        length_penalty=args.length_penalty,
        beam_width=args.beam_width,
        mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
        dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
        input_prompt_length=args.input_prompt_length,
        input_prompt_prefix=args.input_prompt_prefix,
        prefix_silence=args.prefix_silence,
        cfg_strength=args.cfg_strength,
        cfg_rescale=args.cfg_rescale,
        # `is not None` so that an explicit --seed 0 is honored instead of
        # being replaced by the current time.
        seed=args.seed if args.seed is not None else int(time.time()),
        tqdm=True,
        batch_size=args.batch_size,
    )

    # replace values in our template
    html = html.replace(r"${PREAMBLE}", args.preamble)
    html = html.replace(r"${SETTINGS}", str(sampling_kwargs))

    # pull from provided samples
    samples_dirs = {}

    # only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
    if not args.dataset_dir_name:
        samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech"
    else:
        if "validation" in args.dataset_dir_name:
            samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech"

        # automatically pull from anything under the dataset dir
        if args.dataset_dir_name.endswith("/*"):
            args.dataset_dir_name = args.dataset_dir_name[:-2]

            dataset_root = args.demo_dir / args.dataset_dir_name
            for path in [entry for entry in dataset_root.iterdir() if entry.is_dir()]:
                samples_dirs[path.name] = path
        # user provided dataset
        elif (args.demo_dir / args.dataset_dir_name).exists():
            samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name

    # pull from dataset samples
    if args.sample_from_dataset:
        cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
        cfg.dataset.sample_order = "random"
        cfg.dataset.tasks_list = ['tts']

        samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name

        _logger.info("Loading dataloader...")
        dataloader = create_train_dataloader()
        _logger.info("Loaded dataloader.")

        length = min(len(dataloader.dataset), cfg.evaluation.batch_size)
        num = args.dataset_samples if args.dataset_samples else length

        for i in trange(num, desc="Sampling dataset for samples"):
            index = i if not cfg.dataset.sample_shuffle else random.randint(0, len(dataloader.dataset) - 1)
            batch = dataloader.dataset[index]

            if args.dataset_dir_name_prefix:
                entry_dir = args.demo_dir / args.dataset_dir_name / f'{args.dataset_dir_name_prefix}_{i}'
            else:
                entry_dir = args.demo_dir / args.dataset_dir_name / f'{i}'

            (entry_dir / "out").mkdir(parents=True, exist_ok=True)

            metadata = batch["metadata"]

            text = get_random_prompt() if args.random_prompts else metadata["text"]
            language = metadata["language"].lower()

            prompt = entry_dir / "prompt.wav"
            reference = entry_dir / "reference.wav"
            out_path = entry_dir / "out" / "ours.wav"

            if args.skip_existing and out_path.exists():
                continue

            # write_text() closes the file for us.
            (entry_dir / "prompt.txt").write_text(text, encoding="utf-8")
            (entry_dir / "language.txt").write_text(language, encoding="utf-8")

            # NOTE(review): the device is hard-coded to "cuda" here while the
            # metrics below use tts.device — confirm whether this should follow
            # --device as well.
            decode_to_file(batch["proms"].to("cuda"), prompt, device="cuda")
            decode_to_file(batch["resps"].to("cuda"), reference, device="cuda")

    inputs = []
    outputs = []
    metrics_inputs = []
    comparison_inputs = []

    for dataset_name, sample_dir in samples_dirs.items():
        if not sample_dir.exists():
            continue

        samples = []
        speakers = sorted(entry for entry in sample_dir.iterdir() if entry.is_dir())
        sources = ["ms_valle", "f5"] if dataset_name == "librispeech" else []

        # generate demo output
        for speaker_dir in tqdm(speakers, desc=f"Preparing demo for {dataset_name}"):
            # bail if too many samples
            if args.dataset_samples and len(samples) >= args.dataset_samples:
                break

            text = (speaker_dir / "prompt.txt").read_text(encoding="utf-8")

            # language.txt is written as UTF-8 above, so read it back the same
            # way instead of relying on the platform default encoding.
            language_path = speaker_dir / "language.txt"
            language = language_path.read_text(encoding="utf-8") if language_path.exists() else "en"

            prompt = speaker_dir / "prompt.wav"
            reference = speaker_dir / "reference.wav"
            metrics_path = speaker_dir / "metrics.json"
            out_path = speaker_dir / "out" / "ours.wav"
            out_path_comparison = speaker_dir / "out" / f"ours_{comparison_kwargs['suffix']}.wav"
            external_sources = [speaker_dir / "out" / f"{source}.wav" for source in sources]

            audio_samples = [prompt, out_path]
            if args.comparison:
                audio_samples += [out_path_comparison]
            # Missing external outputs stay in the list as None so the column
            # count of the row does not change.
            audio_samples += [p if p.exists() else None for p in external_sources]

            if not args.random_prompts or dataset_name == "librispeech":
                audio_samples += [reference]

            samples.append((
                text,
                audio_samples,
            ))

            # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
            if args.comparison:
                if not (args.skip_existing and out_path_comparison.exists()):
                    comparison_inputs.append((text, prompt, language, out_path_comparison))

                # NOTE(review): the comparison output shares metrics.json with
                # the regular output of the same sample, so one cache entry
                # serves both — confirm this is intended.
                metrics_inputs.append((dataset_name, text, language, out_path_comparison, prompt, reference, metrics_path))

            # Regenerate unless --skip-existing is set and the file is there.
            if not (args.skip_existing and out_path.exists()):
                inputs.append((text, prompt, language, out_path))

            metrics_inputs.append((dataset_name, text, language, out_path, prompt, reference, metrics_path))

        outputs.append((dataset_name, samples))

    if inputs:
        process_batch(tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}))

    if comparison_inputs:
        process_batch(tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}))

    metrics_map = {}
    for dataset_name, text, language, out_path, prompt_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
        # Recompute when there is no cache yet or the audio is newer than it.
        calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)

        if calculate:
            # computes based on word transcriptions outright
            wer_score, cer_score = wer(out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False)
            # compute on words as well, but does not normalize
            wer_un_score, cer_un_score = wer(out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=False, normalize=False)
            # computes on phonemes instead
            pwer_score, per_score = wer(out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model, phonemize=True)

            sim_o_score = sim_o(out_path, prompt_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model)

            metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score, "per": per_score, "pwer": pwer_score, "wer_un": wer_un_score, "cer_un": cer_un_score}
            json_write(metrics, metrics_path)
        else:
            metrics = json_read(metrics_path)
            wer_score, cer_score, per_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["per"], metrics["sim-o"]

        if dataset_name not in metrics_map:
            metrics_map[dataset_name] = {}

        metrics_map[dataset_name][out_path] = (wer_score, cer_score, per_score, sim_o_score)

    def audio_src(audio):
        """Return the ``src`` attribute value for one audio cell."""
        if audio is None:
            # Missing external sample: leave the player empty rather than
            # emitting the literal text "None" as a URL.
            return ""
        if args.audio_path_root:
            return str(audio).replace(str(args.demo_dir), args.audio_path_root)
        return encode(audio)

    # collate entries into HTML
    tables = []
    for dataset_name, samples in outputs:
        dataset_metrics = metrics_map.get(dataset_name)
        # A dataset folder without any sample has nothing to average or show.
        if not samples or not dataset_metrics:
            continue

        # Each row emits four metric cells (WER, CER, PER, SIM-O), so the
        # header carries a PER column too and the columns line up.
        table = (
            "\t\t<h3>${DATASET_NAME}</h3>\n"
            "\t\t<p><b>Average WER:</b> ${WER}<br><b>Average CER:</b> ${CER}<br><b>Average PER:</b> ${PER}<br><b>Average SIM-O:</b> ${SIM-O}<br></p>\n"
            "\t\t<table>\n"
            "\t\t\t<thead>\n"
            "\t\t\t\t<tr>\n"
            "\t\t\t\t\t<th>Text</th>\n"
            "\t\t\t\t\t<th>WER↓</th>\n"
            "\t\t\t\t\t<th>CER↓</th>\n"
            "\t\t\t\t\t<th>PER↓</th>\n"
            "\t\t\t\t\t<th>SIM-O↑</th>\n"
            "\t\t\t\t\t<th>Prompt</th>\n"
            "\t\t\t\t\t<th>Our VALL-E</th>\n"
            "\t\t\t\t\t<!--th>Original VALL-E</th-->\n"
            "\t\t\t\t\t<!--th>F5-TTS</th-->\n"
            "\t\t\t\t\t<th>Ground Truth</th>\n"
            "\t\t\t\t</tr>\n"
            "\t\t\t</thead>\n"
            "\t\t\t<tbody>${SAMPLES}</tbody>\n"
            "\t\t</table>"
        )

        rows = []
        for text, audios in samples:
            # audios[1] is always this sample's own output (out_path).
            scores = dataset_metrics[audios[1]]
            # Escape the prompt so characters such as < or & cannot break the
            # markup of the page.
            row = f'\n\t\t\t<tr>\n\t\t\t\t<td>{html_escape(text, quote=False)}</td>'
            row += f'\n\t\t\t\t<td>{scores[0]:.3f}</td><td>{scores[1]:.3f}</td><td>{scores[2]:.3f}</td><td>{scores[3]:.3f}</td>'
            row += "".join(
                f'\n\t\t\t\t<td><audio controls="controls" preload="none"><source src="{audio_src(audio)}"/></audio></td>'
                for audio in audios
            )
            row += '\n\t\t\t</tr>'
            rows.append(row)

        # write audio into template
        all_scores = list(dataset_metrics.values())
        table = table.replace("${WER}", f'{mean([scores[0] for scores in all_scores]):.3f}')
        table = table.replace("${CER}", f'{mean([scores[1] for scores in all_scores]):.3f}')
        table = table.replace("${PER}", f'{mean([scores[2] for scores in all_scores]):.3f}')
        table = table.replace("${SIM-O}", f'{mean([scores[3] for scores in all_scores]):.3f}')

        table = table.replace("${DATASET_NAME}", dataset_name)
        table = table.replace("${SAMPLES}", "\n".join(rows))

        tables.append(table)

    html = html.replace("${TABLES}", "\n".join(tables))

    if args.comparison:
        disabled, enabled = comparison_kwargs["titles"]
        if args.random_prompts:
            # NOTE(review): the generated header has commented-out <th> lines
            # between "Our VALL-E" and "Ground Truth", so this pattern looks
            # like it can only match text coming from the template itself —
            # confirm.
            html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
        else:
            html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")

    # write demo page
    (args.demo_dir / args.output_filename).write_text(html, encoding="utf-8")
# Script entry point: build the demo page only when this module is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()