From 8838babcba1f180dbb72e9c50b1688bea491dfd7 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 19 Dec 2024 19:08:57 -0600 Subject: [PATCH] sanity checks (and I realized that the model actually had langs set to 4 in the yaml for KO/ZH so................ --- scripts/process_nscripter.py | 52 ++++++++++++++++++++++++++++++++++++ scripts/process_seed-tts.py | 12 ++++++--- vall_e/data.py | 2 +- vall_e/demo.py | 6 ++--- vall_e/export.py | 18 ++++++------- 5 files changed, 73 insertions(+), 17 deletions(-) create mode 100644 scripts/process_nscripter.py diff --git a/scripts/process_nscripter.py b/scripts/process_nscripter.py new file mode 100644 index 0000000..c624377 --- /dev/null +++ b/scripts/process_nscripter.py @@ -0,0 +1,52 @@ +""" +Handles processing NScripter's 0.u file to clean up the pile of audio clips it has + +* to-do: also grab transcriptions +""" + +import os +import re +import json +import argparse +import torch +import shutil +import torchaudio +import numpy as np + +from tqdm.auto import tqdm +from pathlib import Path + +def process( + input_file=Path("./assets/0.u"), + wav_dir=Path("./arc/"), + output_dir=Path("./dataset/"), +): + file = open(input_file, encoding='utf-8').read() + + names = {} + aliases = {} + lines = file.split('\n') + + for line in lines: + if not line.startswith('stralias'): + continue + # ick + try: + key, path = re.findall(r'^stralias (.+?),"(.+?)"$', line)[0] + name = key.split("_")[0] + if name not in names: + (output_dir / name).mkdir(parents=True, exist_ok=True) + names[name] = True + + aliases[key] = Path(path) + except Exception as e: + pass + + for k, v in aliases.items(): + name = k.split("_")[0] + + + print(aliases) + +if __name__ == "__main__": + process() \ No newline at end of file diff --git a/scripts/process_seed-tts.py b/scripts/process_seed-tts.py index c73cd58..712014a 100644 --- a/scripts/process_seed-tts.py +++ b/scripts/process_seed-tts.py @@ -18,10 +18,10 @@ from tqdm.auto import tqdm from pathlib import Path def process( - input_dir=Path("./seedtts_testset/en/"), - list_name="./meta.lst", + input_dir=Path("./seedtts_testset/zh/"), + list_name="./hardcase.lst", wav_dir="./wavs/", - output_dir=Path("./dataset/seed-tts-eval-en/"), + output_dir=Path("./dataset/seed-tts-eval-hard/"), ): language = "auto" @@ -46,7 +46,11 @@ def process( open( output_dir / filename / "prompt.txt", "w", encoding="utf-8" ).write( text ) open( output_dir / filename / "language.txt", "w", encoding="utf-8" ).write( language ) - shutil.copy((input_dir / wav_dir / filename).with_suffix(".wav"), output_dir / filename / "reference.wav" ) + reference_wav = (input_dir / wav_dir / filename).with_suffix(".wav") + if not reference_wav.exists(): + continue + + shutil.copy(reference_wav, output_dir / filename / "reference.wav" ) shutil.copy(input_dir / prompt_wav, output_dir / filename / "prompt.wav" ) if __name__ == "__main__": diff --git a/vall_e/data.py b/vall_e/data.py index 9fcfe98..5b9a395 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True): return text @cache -def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ): +def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/tongue_twisters.txt") ): duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range sentences = [ "The birch canoe slid on the smooth planks.", diff --git a/vall_e/demo.py b/vall_e/demo.py index 16ab412..e9d5bb6 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -289,10 +289,10 @@ def main(): # only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out if not args.dataset_dir_name: - samples_dirs["librispeech"] = args.demo_dir / "librispeech" + samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech" else: if "validation" in args.dataset_dir_name: - samples_dirs["librispeech"] = args.demo_dir / "librispeech" + samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech" # automatically pull from anything under the dataset dir if args.dataset_dir_name.endswith("/*"): @@ -444,7 +444,7 @@ def main(): if dataset_name not in metrics_map: metrics_map[dataset_name] = {} - metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score) + metrics_map[dataset_name][out_path] = (wer_score, cer_score, per_score, sim_o_score) # collate entries into HTML tables = [] diff --git a/vall_e/export.py b/vall_e/export.py index a0212b2..d8e43e0 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -13,15 +13,13 @@ from .utils.io import torch_save, torch_load # *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed @torch.no_grad() def convert_to_hf( state_dict, config = None, save_path = None ): - # to-do: infer all of this from the existing state_dict, should be easy by checking shape - model_dim = 1024 + n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape - n_text_tokens = 256 - n_audio_tokens = 1024 - n_resp_levels = 8 + n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0] + n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0] n_len_tokens = 11 - n_lang_tokens = 4 - n_task_tokens = 9 + n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0] + n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0] # the new tokenizer to use tokenizer_append = {} @@ -45,6 +43,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ): "ja", "de", "fr", + "zh", + "ko", ] task_map = [ "tts", @@ -100,8 +100,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ): token_start = token_end token_end += l_tokens[2] // 2 embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight'] - classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.8.weight'] - classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.8.bias'] + classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight'] + classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias'] for t in range(n_audio_tokens): tokenizer_append[f''] = token_start + t tokenizer_append[f''] = token_start + 1024