dataset preparation script updates, caved and am using HF tokenizer now

mrq 2024-04-21 14:49:18 -05:00
parent a8ffa88844
commit 071fb97777
8 changed files with 211 additions and 107 deletions
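In short, this commit drops the hand-maintained phone symmap in favor of a Hugging Face tokenizer trained on phoneme strings. A minimal sketch of the new encode path, assuming the tokenizer.json produced by the training script added below (the phoneme string here is illustrative):

from transformers import PreTrainedTokenizerFast

# load the BPE tokenizer trained over phoneme strings (built by the new script below)
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tokenizer.json")
ids = tokenizer.encode("hˈɛloʊ wˈɜːld")  # token ids, wrapped in <bos>/<eos> by the template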


@@ -1,51 +1,23 @@
dataset:
  training: []
  validation: []
  noise: []
  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
  use_hdf5: True
  use_metadata: True
  hdf5_flag: r
  validate: True
  workers: 2
  cache: True
  phones_range: [4, 512]
  duration_range: [1.0, 32.0]
  random_utterance: 1.0
  max_prompts: 3
  prompt_duration: 6.0
  sample_type: speaker
  tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]

models:
  _prom_levels: 8
  _max_levels: 8
  _models:
    - name: "ar+nar"
      size: "full"
      resp_levels: 8
      prom_levels: 8
      tasks: 8
      arch_type: "retnet"
      training: True
      version: 3
    - name: "ar+nar"
      size: "full"
      resp_levels: 8
      prom_levels: 8
      tasks: 8
      langs: 2
      tones: 1
      arch_type: "retnet"
      training: True
      version: 3

hyperparameters:
  batch_size: 8
  gradient_accumulation_steps: 32
  gradient_clipping: 100
  batch_size: 4
  gradient_accumulation_steps: 4
  gradient_clipping: 10
  optimizer: Prodigy
  optimizer: Adagrad
  torch_optimizer: True
  learning_rate: 0.0625
  learning_rate: 1.0e-2
  scheduler_type: ""
  #scheduler_type: OneCycle
@@ -67,22 +39,24 @@ hyperparameters:
  # decay_mom_rate: 0.0

evaluation:
  batch_size: 16
  frequency: 250
  size: 16
  batch_size: 8
  frequency: 10000
  size: 8

  steps: 450
  steps: 500
  ar_temperature: 0.95
  nar_temperature: 0.25
  load_disabled_engines: True

trainer:
  no_logger: True
  iterations: 1_000_000

  save_tag: step
  save_on_oom: True
  save_on_quit: True
  save_frequency: 100
  save_frequency: 250
  export_on_save: True

  keep_last_checkpoints: 4
@@ -91,33 +65,82 @@ trainer:
  load_disabled_engines: False

  #load_state_dict: True
  #strict_loading: False
  strict_loading: False
  #load_tag: "9500"
  #load_states: False
  #restart_step_count: True

  gc_mode: None # "global_step"

  weight_dtype: bfloat16
  weight_dtype: float32
  amp: False

  backend: deepspeed
  deepspeed:
    inferencing: True
    zero_optimization_level: 0
    use_compression_training: True
    use_compression_training: False

  activation_checkpointing: True

  load_webui: True

inference:
  use_vocos: True
  backend: deepspeed
  audio_backend: "dac"
  normalize: False

  weight_dtype: bfloat16
  weight_dtype: float32
  amp: False

bitsandbytes:
  enabled: False
  injects: True
  linear: True
  embedding: True
  injects: False
  replace: False
  linear: False
  embedding: False

  bitnet: False

fp8:
  enabled: False
  backend: "te"

experimental: True

dataset:
  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
  speaker_group_getter: "lambda p: f'{p.parts[-3]}'"
  speaker_languages:
    ja: []

  use_hdf5: True
  use_metadata: True
  hdf5_flag: r
  validate: True

  workers: 8
  cache: True

  #phones_range: [4, 512]
  #duration_range: [1.0, 32.0]
  phones_range: [0, 512]
  duration_range: [0.0, 64.0]

  random_utterance: 1.0
  max_prompts: 3
  prompt_duration: 6.0

  max_resps: 1
  p_resp_append: 0.25

  sample_type: speaker

  tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]

  training: []
  validation: []
  noise: []
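Since this flattened diff shows old values immediately followed by new ones, the hyperparameters above read as 8 → 4 for batch_size and 32 → 4 for gradient_accumulation_steps. The effective batch size shrinks accordingly (quick arithmetic, assuming the usual definition of effective batch = batch_size × gradient_accumulation_steps):

old_effective_batch = 8 * 32  # 256 samples per optimizer step
new_effective_batch = 4 * 4   # 16 samples per optimizer step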

Binary file not shown.


@@ -1,8 +1,8 @@
import os
import json
input_dataset = "small+medium"
output_dataset = "LibriLight-6K"
input_dataset = "duplicate"
output_dataset = "LibriLight-4K"
for speaker_id in os.listdir(f'./{input_dataset}/'):
    if not os.path.isdir(f'./{input_dataset}/{speaker_id}/'):


@@ -8,9 +8,14 @@ from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
input_audio = "voices"
input_audio = "voice"
input_metadata = "metadata"
output_dataset = "training"
output_dataset = "training-24K"
missing = {
    "transcription": [],
    "audio": []
}
device = "cuda"
@@ -31,13 +36,15 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
        metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
        if not metadata_path.exists():
            print("Does not exist:", metadata_path)
            #print("Does not exist:", metadata_path)
            missing["transcription"].append(str(metadata_path))
            continue

        try:
            metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
        except Exception as e:
            print("Failed to load metadata:", metadata_path, e)
            #print("Failed to load metadata:", metadata_path, e)
            missing["transcription"].append(str(metadata_path))
            continue

        txts = []
@@ -46,7 +53,8 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
        for filename in metadata.keys():
            inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
            if not inpath.exists():
                print("Does not exist:", inpath)
                #print("Does not exist:", inpath)
                missing["audio"].append(str(inpath))
                continue

            extension = os.path.splitext(filename)[-1][1:]
@@ -117,21 +125,26 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
                    waveform[:, start:end],
                    sample_rate
                ))

        for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"):
            outpath, text, language = job
            phones = valle_phonemize(text)
            data = {
                "text": text.strip(),
                "phonemes": phones,
                "language": language,
            }
            open(_replace_file_extension(outpath, ".json"), 'w', encoding='utf-8').write(json.dumps(data))

        for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
            try:
                outpath, waveform, sample_rate = job
                qnt = valle_quantize(waveform, sr=sample_rate, device=device)
                qnt.save(_replace_file_extension(outpath, ".dac"))
            except Exception as e:
                print(f"Failed to quantize: {outpath}:", e)
                continue

        if len(txts) > 0:
            for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"):
                outpath, text, language = job
                phones = valle_phonemize(text)
                data = {
                    "text": text.strip(),
                    "phonemes": phones,
                    "language": language,
                }
                open(_replace_file_extension(outpath, ".json"), 'w', encoding='utf-8').write(json.dumps(data))

        if len(wavs) > 0:
            for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
                try:
                    outpath, waveform, sample_rate = job
                    qnt = valle_quantize(waveform, sr=sample_rate, device=device)
                    qnt.save(_replace_file_extension(outpath, ".dac"))
                except Exception as e:
                    print(f"Failed to quantize: {outpath}:", e)
                    continue

open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))


@@ -0,0 +1,57 @@
import os
import json
import torch
import torchaudio

from tqdm.auto import tqdm
from pathlib import Path

from tokenizers import Tokenizer
from tokenizers.models import BPE, Unigram, WordLevel, WordPiece
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing

input_metadata = "training-24K"
output_file = Path("./dataset.json")
tokenizer_data = []

def pad(num, zeroes):
    return str(num).zfill(zeroes+1)

if output_file.exists():
    tokenizer_data = json.loads(open(str(output_file), "r", encoding="utf-8").read())
else:
    for dataset_name in os.listdir(f'./{input_metadata}/'):
        if not os.path.isdir(f'./{input_metadata}/{dataset_name}/'):
            continue
        for speaker_id in tqdm(os.listdir(f'./{input_metadata}/{dataset_name}/'), desc="Processing speaker"):
            if not os.path.isdir(f'./{input_metadata}/{dataset_name}/{speaker_id}'):
                continue
            for id in os.listdir(f'./{input_metadata}/{dataset_name}/{speaker_id}/'):
                if ".json" not in id:
                    continue
                metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/{id}')
                metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
                tokenizer_data.append( "".join(metadata["phonemes"]) )
    open(output_file, 'w', encoding='utf-8').write(json.dumps(tokenizer_data))

unk_token = "<unk>"
# order the special tokens so <bos> and <eos> train as ids 1 and 2, matching the TemplateProcessing below
spl_tokens = [unk_token, "<bos>", "<eos>", "<mask>"]

trainer = BpeTrainer(special_tokens = spl_tokens, vocab_size = 256)
tokenizer = Tokenizer(BPE(unk_token = unk_token))
tokenizer.pre_tokenizer = Whitespace()
tokenizer.post_processor = TemplateProcessing(
    single="<bos> $A <eos>",
    special_tokens=[("<bos>", 1), ("<eos>", 2)],
)

tokenizer.train_from_iterator(tokenizer_data, trainer=trainer)
tokenizer.save("./tokenizer.json")
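A quick sanity check of the trained tokenizer (a sketch; the phoneme string is illustrative):

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("./tokenizer.json")
enc = tokenizer.encode("ðˈɪs ɪz ɐ tˈɛst")
print(enc.tokens)  # BPE merges over the phoneme string, wrapped in <bos>/<eos>
print(enc.ids)     # ids stay below the trained vocab_size of 256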


@@ -18,6 +18,9 @@ from omegaconf import OmegaConf
from .utils.distributed import world_size
# Yuck
from transformers import PreTrainedTokenizerFast
@dataclass()
class _Config:
    cfg_path: str | None = None
@@ -540,10 +543,12 @@ class Config(_Config):
    inference: Inference = field(default_factory=lambda: Inference)
    bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)

    tokenizer: str = "./tokenizer.json"

    fp8: FP8 = field(default_factory=lambda: FP8)

    sample_rate: int = 24_000
    variable_sample_rate: bool = False
    variable_sample_rate: bool = True

    @property
    def distributed(self):
@@ -611,16 +616,19 @@ cfg = Config.from_cli()
# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
try:
    cfg.format()
    # cached_property stopped working...
    if cfg.dataset.use_hdf5:
        cfg.load_hdf5()
except Exception as e:
    print(e)
    print("Error while parsing config YAML:", e)
    pass

try:
    from transformers import PreTrainedTokenizerFast
    cfg.tokenizer = (cfg.relpath if cfg.cfg_path is not None else Path("./data/")) / cfg.tokenizer
    cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
except Exception as e:
    print("Error while parsing tokenizer:", e)
    pass

if __name__ == "__main__":
    print(cfg)
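After this block runs, cfg.tokenizer is no longer a path string but a PreTrainedTokenizerFast instance. The data pipeline below leans on just two of its methods (a sketch, assuming a trained tokenizer.json sits next to the loaded config):

vocab = cfg.tokenizer.get_vocab()       # replaces the hard-coded phone symmap in data.py
ids = cfg.tokenizer.encode("haʊdˈiː")   # replaces the manual symmap lookups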


@@ -24,17 +24,17 @@ from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
_logger = logging.getLogger(__name__)
# to-do: clean up this symmap mess
def get_phone_symmap():
    if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
        return json.loads( cfg.hdf5['symmap'].asstr()[()] )
    return cfg.tokenizer.get_vocab()
return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, '': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, '': 220, 'eˈ': 221, 'ʍ': 222, '': 223, '': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
def tokenize( phones ):
    return cfg.tokenizer.encode( "".join(phones) )
    #return [*map(get_phone_symmap.get, _get_phones(path))]

def get_lang_symmap():
    return {
@ -178,7 +178,9 @@ def _get_phones(path, language="en"):
    else:
        content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
        content = _cleanup_phones( content )

    return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
    return "".join(content)
    #return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]

def _interleaved_reorder(l, fn):
    groups = defaultdict(list)
@@ -435,7 +437,7 @@ class Dataset(_Dataset):
            text = torch.from_numpy(text).to(self.text_dtype)
            resps = torch.from_numpy(resps).to(torch.int16)
        else:
            text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
            text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
            resps = _load_quants(path)

        lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
@@ -847,18 +849,21 @@ def create_dataset_hdf5( skip_existing=True ):
        # audio
        if audios:
            qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
            codes = torch.from_numpy(qnt["codes"].astype(int))[0].t()
            codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(dtype=torch.int16)

            if _get_quant_extension() == ".dac":
                if "audio" in group:
                    del group["audio"]

                duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
                metadata[id]["metadata"] = qnt["metadata"]
                metadata[id]["metadata"] = {
                    "original_length": qnt["metadata"]["original_length"],
                    "sample_rate": qnt["metadata"]["sample_rate"],
                }
            else:
                qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
                duration = qnt.shape[0] / 75

                group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
                group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')

            group.attrs['duration'] = duration
            metadata[id]["duration"] = duration
@@ -869,17 +874,22 @@
        # text
        if texts:
            if _get_quant_extension() == ".json":
                j_son = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
                content = j_son["phonemes"]
                json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
                content = json_metadata["phonemes"]
            else:
                content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")

            """
            phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
            for s in set(phones):
                if s not in symmap:
                    symmap[s] = len(symmap.keys())
            phn = [ symmap[s] for s in phones ]
            """
            phn = cfg.tokenizer.encode("".join(content))
            phn = np.array(phn).astype(np.uint8)

            if "text" in group:
                del group["text"]


@@ -91,15 +91,8 @@ class TTS():
        return text

        content = g2p.encode(text, language=language)
        content = _cleanup_phones( content )

        # ick
        try:
            phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
            return torch.tensor([*map(self.symmap.get, phones)])
        except Exception as e:
            pass

        phones = [ " " if not p else p for p in content ]
        return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
        return torch.tensor(cfg.tokenizer.encode( "".join( content ) ))

    def encode_lang( self, language ):
        symmap = get_lang_symmap()