sanity checks (and I realized that the model actually had langs set to 4 in the yaml for KO/ZH so................

This commit is contained in:
mrq 2024-12-19 19:08:57 -06:00
parent 7617b6485f
commit 8838babcba
5 changed files with 73 additions and 17 deletions

View File

@ -0,0 +1,52 @@
"""
Handles processing NScripter's 0.u file to clean up the pile of audio clips it has
* to-do: also grab transcriptions
"""
import os
import re
import json
import argparse
import torch
import shutil
import torchaudio
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
def process(
input_file=Path("./assets/0.u"),
wav_dir=Path("./arc/"),
output_dir=Path("./dataset/"),
):
file = open(input_file, encoding='utf-8').read()
names = {}
aliases = {}
lines = file.split('\n')
for line in lines:
if not line.startswith('stralias'):
continue
# ick
try:
key, path = re.findall(r'^stralias (.+?),"(.+?)"$', line)[0]
name = key.split("_")[0]
if name not in names:
(output_dir / name).mkdir(parents=True, exist_ok=True)
names[name] = True
aliases[key] = Path(path)
except Exception as e:
pass
for k, v in aliases.items():
name = k.split("_")[0]
print(aliases)
if __name__ == "__main__":
process()

View File

@ -18,10 +18,10 @@ from tqdm.auto import tqdm
from pathlib import Path
def process(
input_dir=Path("./seedtts_testset/en/"),
list_name="./meta.lst",
input_dir=Path("./seedtts_testset/zh/"),
list_name="./hardcase.lst",
wav_dir="./wavs/",
output_dir=Path("./dataset/seed-tts-eval-en/"),
output_dir=Path("./dataset/seed-tts-eval-hard/"),
):
language = "auto"
@ -46,7 +46,11 @@ def process(
open( output_dir / filename / "prompt.txt", "w", encoding="utf-8" ).write( text )
open( output_dir / filename / "language.txt", "w", encoding="utf-8" ).write( language )
shutil.copy((input_dir / wav_dir / filename).with_suffix(".wav"), output_dir / filename / "reference.wav" )
reference_wav = (input_dir / wav_dir / filename).with_suffix(".wav")
if not reference_wav.exists():
continue
shutil.copy(reference_wav, output_dir / filename / "reference.wav" )
shutil.copy(input_dir / prompt_wav, output_dir / filename / "prompt.wav" )
if __name__ == "__main__":

View File

@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True):
return text
@cache
def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ):
def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/tongue_twisters.txt") ):
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
sentences = [
"The birch canoe slid on the smooth planks.",

View File

@ -289,10 +289,10 @@ def main():
# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
if not args.dataset_dir_name:
samples_dirs["librispeech"] = args.demo_dir / "librispeech"
samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech"
else:
if "validation" in args.dataset_dir_name:
samples_dirs["librispeech"] = args.demo_dir / "librispeech"
samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech"
# automatically pull from anything under the dataset dir
if args.dataset_dir_name.endswith("/*"):
@ -444,7 +444,7 @@ def main():
if dataset_name not in metrics_map:
metrics_map[dataset_name] = {}
metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score)
metrics_map[dataset_name][out_path] = (wer_score, cer_score, per_score, sim_o_score)
# collate entries into HTML
tables = []

View File

@ -13,15 +13,13 @@ from .utils.io import torch_save, torch_load
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
@torch.no_grad()
def convert_to_hf( state_dict, config = None, save_path = None ):
# to-do: infer all of this from the existing state_dict, should be easy by checking shape
model_dim = 1024
n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape
n_text_tokens = 256
n_audio_tokens = 1024
n_resp_levels = 8
n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0]
n_len_tokens = 11
n_lang_tokens = 4
n_task_tokens = 9
n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
# the new tokenizer to use
tokenizer_append = {}
@ -45,6 +43,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
"ja",
"de",
"fr",
"zh",
"ko",
]
task_map = [
"tts",
@ -100,8 +100,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
token_start = token_end
token_end += l_tokens[2] // 2
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.8.weight']
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.8.bias']
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
for t in range(n_audio_tokens):
tokenizer_append[f'<NAR:0:0:{t}>'] = token_start + t
tokenizer_append[f'<NAR:0:0:STOP>'] = token_start + 1024