sanity checks (and I realized the model actually had langs set to 4 in the yaml for KO/ZH, so…)
parent 7617b6485f
commit 8838babcba
scripts/process_nscripter.py (new file, 52 lines)
@@ -0,0 +1,52 @@
"""
|
||||
Handles processing NScripter's 0.u file to clean up the pile of audio clips it has
|
||||
|
||||
* to-do: also grab transcriptions
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
import torch
|
||||
import shutil
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
|
||||
def process(
|
||||
input_file=Path("./assets/0.u"),
|
||||
wav_dir=Path("./arc/"),
|
||||
output_dir=Path("./dataset/"),
|
||||
):
|
||||
file = open(input_file, encoding='utf-8').read()
|
||||
|
||||
names = {}
|
||||
aliases = {}
|
||||
lines = file.split('\n')
|
||||
|
||||
for line in lines:
|
||||
if not line.startswith('stralias'):
|
||||
continue
|
||||
# ick
|
||||
try:
|
||||
key, path = re.findall(r'^stralias (.+?),"(.+?)"$', line)[0]
|
||||
name = key.split("_")[0]
|
||||
if name not in names:
|
||||
(output_dir / name).mkdir(parents=True, exist_ok=True)
|
||||
names[name] = True
|
||||
|
||||
aliases[key] = Path(path)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
for k, v in aliases.items():
|
||||
name = k.split("_")[0]
|
||||
|
||||
|
||||
print(aliases)
|
||||
|
||||
if __name__ == "__main__":
|
||||
process()
|
|
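For reference, `stralias` lines in an NScripter script alias a key to a resource path, and the script above buckets clips by everything before the first underscore in the key. Actual 0.u entries vary per game, so the line below is only a hypothetical illustration of the format the regex expects:

```python
import re
from pathlib import Path

# hypothetical stralias entry; real 0.u lines differ per game
line = 'stralias alice_0001,"arc\\voice\\alice_0001.wav"'

key, path = re.findall(r'^stralias (.+?),"(.+?)"$', line)[0]
name = key.split("_")[0]  # speaker bucket, e.g. "alice"
print(name, Path(path))   # -> alice arc\voice\alice_0001.wav
```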
@@ -18,10 +18,10 @@ from tqdm.auto import tqdm
 from pathlib import Path
 
 def process(
-	input_dir=Path("./seedtts_testset/en/"),
-	list_name="./meta.lst",
+	input_dir=Path("./seedtts_testset/zh/"),
+	list_name="./hardcase.lst",
 	wav_dir="./wavs/",
-	output_dir=Path("./dataset/seed-tts-eval-en/"),
+	output_dir=Path("./dataset/seed-tts-eval-hard/"),
 ):
 	language = "auto"
 
@@ -46,7 +46,11 @@ def process(
 		open( output_dir / filename / "prompt.txt", "w", encoding="utf-8" ).write( text )
 		open( output_dir / filename / "language.txt", "w", encoding="utf-8" ).write( language )
 
-		shutil.copy((input_dir / wav_dir / filename).with_suffix(".wav"), output_dir / filename / "reference.wav" )
+		reference_wav = (input_dir / wav_dir / filename).with_suffix(".wav")
+		if not reference_wav.exists():
+			continue
+
+		shutil.copy(reference_wav, output_dir / filename / "reference.wav" )
 		shutil.copy(input_dir / prompt_wav, output_dir / filename / "prompt.wav" )
 
 if __name__ == "__main__":
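Presumably hardcase.lst can reference clips that aren't shipped with every copy of the testset, so the copy is now guarded instead of letting `shutil.copy` raise. The same skip-if-missing pattern in isolation, with hypothetical paths:

```python
import shutil
from pathlib import Path

# hypothetical paths; the point is only to probe before copying
for filename in ("clip_0001", "clip_0002"):
    reference_wav = (Path("./in/wavs") / filename).with_suffix(".wav")
    if not reference_wav.exists():
        continue  # listed entry with no wav on disk, skip it
    out_dir = Path("./out") / filename
    out_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy(reference_wav, out_dir / "reference.wav")
```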
@@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True):
 	return text
 
 @cache
-def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ):
+def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/tongue_twisters.txt") ):
 	duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
 	sentences = [
 		"The birch canoe slid on the smooth planks.",
@@ -289,10 +289,10 @@ def main():
 
 	# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
 	if not args.dataset_dir_name:
-		samples_dirs["librispeech"] = args.demo_dir / "librispeech"
+		samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech"
 	else:
 		if "validation" in args.dataset_dir_name:
-			samples_dirs["librispeech"] = args.demo_dir / "librispeech"
+			samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech"
 
 		# automatically pull from anything under the dataset dir
 		if args.dataset_dir_name.endswith("/*"):
@@ -444,7 +444,7 @@ def main():
 		if dataset_name not in metrics_map:
 			metrics_map[dataset_name] = {}
 
-		metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score)
+		metrics_map[dataset_name][out_path] = (wer_score, cer_score, per_score, sim_o_score)
 
 	# collate entries into HTML
 	tables = []
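The metrics tuple grows from three entries to four (`per_score`, presumably a phoneme error rate, joins WER/CER/speaker similarity), so anything that unpacks `metrics_map` downstream has to change in lockstep. A hypothetical consumer, with dummy numbers purely to show the new shape:

```python
# dummy values illustrating the new 4-tuple layout
metrics_map = {"seed-tts-eval-hard": {"out/0.wav": (0.12, 0.08, 0.10, 0.73)}}

for dataset_name, entries in metrics_map.items():
    for out_path, (wer, cer, per, sim_o) in entries.items():
        print(dataset_name, out_path, wer, cer, per, sim_o)
```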
@@ -13,15 +13,13 @@ from .utils.io import torch_save, torch_load
 # *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
 @torch.no_grad()
 def convert_to_hf( state_dict, config = None, save_path = None ):
 	# to-do: infer all of this from the existing state_dict, should be easy by checking shape
-	model_dim = 1024
-
-	n_text_tokens = 256
-	n_audio_tokens = 1024
-	n_resp_levels = 8
+	n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape
+	n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
+	n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0]
 	n_len_tokens = 11
-	n_lang_tokens = 4
-	n_task_tokens = 9
+	n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
+	n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
 
 	# the new tokenizer to use
 	tokenizer_append = {}
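This hunk is what the commit message is pointing at: rather than trusting hand-maintained constants (the yaml claimed 4 languages even with KO/ZH in the mix), the converter now reads each count straight off the checkpoint's tensor shapes. A minimal standalone sketch of the idea, assuming the DeepSpeed-style layout used in the diff (weights under a top-level 'module' key); the checkpoint path is hypothetical:

```python
import torch

# hypothetical path; key names are the ones used in the diff
state_dict = torch.load("./ckpt/fp32.pth", map_location="cpu")
module = state_dict['module']

# an embedding's first dim is its vocab size, its second the model width,
# so the trained counts can be read off the tensors instead of the yaml
n_text_tokens, model_dim = module['text_emb.weight'].shape
n_lang_tokens = module['langs_emb.weight'].shape[0]  # the real count, not the yaml's stale 4
n_task_tokens = module['tasks_emb.weight'].shape[0]
print(n_text_tokens, model_dim, n_lang_tokens, n_task_tokens)
```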
@@ -45,6 +43,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 		"ja",
 		"de",
 		"fr",
+		"zh",
+		"ko",
 	]
 	task_map = [
 		"tts",
@@ -100,8 +100,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 	token_start = token_end
 	token_end += l_tokens[2] // 2
 	embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
-	classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.8.weight']
-	classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.8.bias']
+	classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
+	classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
 	for t in range(n_audio_tokens):
 		tokenizer_append[f'<NAR:0:0:{t}>'] = token_start + t
 	tokenizer_append[f'<NAR:0:0:STOP>'] = token_start + 1024
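The switch to `token_end-1` reads like a shape fix: the embedding slice for this block (`l_tokens[2] // 2` rows, including the `<NAR:0:0:STOP>` entry at `token_start + 1024`) is apparently one row larger than `classifiers.proj.8`, so writing the classifier over the full slice would misalign by one. A defensive alternative sizes the slice from the source tensor itself; a sketch with stand-in tensors:

```python
import torch

# stand-in tensors for classifier.weight and classifiers.proj.8.weight
dst = torch.zeros(2048, 1024)
src = torch.randn(1024, 1024)

token_start = 512
# derive the slice from the source so a row miscount can't break the copy
dst[token_start : token_start + src.shape[0]] = src
```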