sanity checks (and I realized that the model actually had langs set to 4 in the yaml for KO/ZH so................
This commit is contained in:
parent
7617b6485f
commit
8838babcba
52
scripts/process_nscripter.py
Normal file
52
scripts/process_nscripter.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
"""
|
||||||
|
Handles processing NScripter's 0.u file to clean up the pile of audio clips it has
|
||||||
|
|
||||||
|
* to-do: also grab transcriptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import shutil
|
||||||
|
import torchaudio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def process(
|
||||||
|
input_file=Path("./assets/0.u"),
|
||||||
|
wav_dir=Path("./arc/"),
|
||||||
|
output_dir=Path("./dataset/"),
|
||||||
|
):
|
||||||
|
file = open(input_file, encoding='utf-8').read()
|
||||||
|
|
||||||
|
names = {}
|
||||||
|
aliases = {}
|
||||||
|
lines = file.split('\n')
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if not line.startswith('stralias'):
|
||||||
|
continue
|
||||||
|
# ick
|
||||||
|
try:
|
||||||
|
key, path = re.findall(r'^stralias (.+?),"(.+?)"$', line)[0]
|
||||||
|
name = key.split("_")[0]
|
||||||
|
if name not in names:
|
||||||
|
(output_dir / name).mkdir(parents=True, exist_ok=True)
|
||||||
|
names[name] = True
|
||||||
|
|
||||||
|
aliases[key] = Path(path)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for k, v in aliases.items():
|
||||||
|
name = k.split("_")[0]
|
||||||
|
|
||||||
|
|
||||||
|
print(aliases)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
process()
|
|
@ -18,10 +18,10 @@ from tqdm.auto import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
def process(
|
def process(
|
||||||
input_dir=Path("./seedtts_testset/en/"),
|
input_dir=Path("./seedtts_testset/zh/"),
|
||||||
list_name="./meta.lst",
|
list_name="./hardcase.lst",
|
||||||
wav_dir="./wavs/",
|
wav_dir="./wavs/",
|
||||||
output_dir=Path("./dataset/seed-tts-eval-en/"),
|
output_dir=Path("./dataset/seed-tts-eval-hard/"),
|
||||||
):
|
):
|
||||||
language = "auto"
|
language = "auto"
|
||||||
|
|
||||||
|
@ -46,7 +46,11 @@ def process(
|
||||||
open( output_dir / filename / "prompt.txt", "w", encoding="utf-8" ).write( text )
|
open( output_dir / filename / "prompt.txt", "w", encoding="utf-8" ).write( text )
|
||||||
open( output_dir / filename / "language.txt", "w", encoding="utf-8" ).write( language )
|
open( output_dir / filename / "language.txt", "w", encoding="utf-8" ).write( language )
|
||||||
|
|
||||||
shutil.copy((input_dir / wav_dir / filename).with_suffix(".wav"), output_dir / filename / "reference.wav" )
|
reference_wav = (input_dir / wav_dir / filename).with_suffix(".wav")
|
||||||
|
if not reference_wav.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
shutil.copy(reference_wav, output_dir / filename / "reference.wav" )
|
||||||
shutil.copy(input_dir / prompt_wav, output_dir / filename / "prompt.wav" )
|
shutil.copy(input_dir / prompt_wav, output_dir / filename / "prompt.wav" )
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -192,7 +192,7 @@ def normalize_text(text, language="auto", full=True):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/harvard_sentences.txt") ):
|
def get_random_prompts( validation=False, min_length=0, tokenized=False, source_path=Path("./data/tongue_twisters.txt") ):
|
||||||
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
|
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
|
||||||
sentences = [
|
sentences = [
|
||||||
"The birch canoe slid on the smooth planks.",
|
"The birch canoe slid on the smooth planks.",
|
||||||
|
|
|
@ -289,10 +289,10 @@ def main():
|
||||||
|
|
||||||
# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
|
# only add the existing librispeech validation dataset if i'm doing validation so I can stop commenting this out
|
||||||
if not args.dataset_dir_name:
|
if not args.dataset_dir_name:
|
||||||
samples_dirs["librispeech"] = args.demo_dir / "librispeech"
|
samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech"
|
||||||
else:
|
else:
|
||||||
if "validation" in args.dataset_dir_name:
|
if "validation" in args.dataset_dir_name:
|
||||||
samples_dirs["librispeech"] = args.demo_dir / "librispeech"
|
samples_dirs["librispeech"] = args.demo_dir / "datasets" / "librispeech"
|
||||||
|
|
||||||
# automatically pull from anything under the dataset dir
|
# automatically pull from anything under the dataset dir
|
||||||
if args.dataset_dir_name.endswith("/*"):
|
if args.dataset_dir_name.endswith("/*"):
|
||||||
|
@ -444,7 +444,7 @@ def main():
|
||||||
if dataset_name not in metrics_map:
|
if dataset_name not in metrics_map:
|
||||||
metrics_map[dataset_name] = {}
|
metrics_map[dataset_name] = {}
|
||||||
|
|
||||||
metrics_map[dataset_name][out_path] = (wer_score, cer_score, sim_o_score)
|
metrics_map[dataset_name][out_path] = (wer_score, cer_score, per_score, sim_o_score)
|
||||||
|
|
||||||
# collate entries into HTML
|
# collate entries into HTML
|
||||||
tables = []
|
tables = []
|
||||||
|
|
|
@ -13,15 +13,13 @@ from .utils.io import torch_save, torch_load
|
||||||
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
|
# *will* require retraining because the classifier is in one contiguous space, and proms are NOT summed
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_to_hf( state_dict, config = None, save_path = None ):
|
def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
# to-do: infer all of this from the existing state_dict, should be easy by checking shape
|
n_text_tokens, model_dim = state_dict['module']['text_emb.weight'].shape
|
||||||
model_dim = 1024
|
|
||||||
|
|
||||||
n_text_tokens = 256
|
n_audio_tokens = state_dict['module']['proms_emb.embeddings.0.weight'].shape[0]
|
||||||
n_audio_tokens = 1024
|
n_resp_levels = state_dict['module']['rvq_l_emb.weight'].shape[0]
|
||||||
n_resp_levels = 8
|
|
||||||
n_len_tokens = 11
|
n_len_tokens = 11
|
||||||
n_lang_tokens = 4
|
n_lang_tokens = state_dict['module']['langs_emb.weight'].shape[0]
|
||||||
n_task_tokens = 9
|
n_task_tokens = state_dict['module']['tasks_emb.weight'].shape[0]
|
||||||
|
|
||||||
# the new tokenizer to use
|
# the new tokenizer to use
|
||||||
tokenizer_append = {}
|
tokenizer_append = {}
|
||||||
|
@ -45,6 +43,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
"ja",
|
"ja",
|
||||||
"de",
|
"de",
|
||||||
"fr",
|
"fr",
|
||||||
|
"zh",
|
||||||
|
"ko",
|
||||||
]
|
]
|
||||||
task_map = [
|
task_map = [
|
||||||
"tts",
|
"tts",
|
||||||
|
@ -100,8 +100,8 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
|
||||||
token_start = token_end
|
token_start = token_end
|
||||||
token_end += l_tokens[2] // 2
|
token_end += l_tokens[2] // 2
|
||||||
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
|
embedding.weight[token_start:token_end] = state_dict['module'][f'resps_emb.embeddings.8.weight']
|
||||||
classifier.weight[token_start:token_end] = state_dict['module']['classifiers.proj.8.weight']
|
classifier.weight[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.weight']
|
||||||
classifier.bias[token_start:token_end] = state_dict['module']['classifiers.proj.8.bias']
|
classifier.bias[token_start:token_end-1] = state_dict['module']['classifiers.proj.8.bias']
|
||||||
for t in range(n_audio_tokens):
|
for t in range(n_audio_tokens):
|
||||||
tokenizer_append[f'<NAR:0:0:{t}>'] = token_start + t
|
tokenizer_append[f'<NAR:0:0:{t}>'] = token_start + t
|
||||||
tokenizer_append[f'<NAR:0:0:STOP>'] = token_start + 1024
|
tokenizer_append[f'<NAR:0:0:STOP>'] = token_start + 1024
|
||||||
|
|
Loading…
Reference in New Issue
Block a user