dataset preparation script updates, caved and am using HF tokenizer now

mrq 2024-04-21 14:49:18 -05:00
parent a8ffa88844
commit 071fb97777
8 changed files with 211 additions and 107 deletions
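
The gist of the change: phonemes are no longer mapped through a hardcoded symmap dict; they are joined into a single string and encoded with a trained HuggingFace BPE tokenizer. A minimal sketch of the new text path, assuming a tokenizer.json produced by the training script added below (the phoneme list is illustrative, not real g2p output):

# sketch of the new path; mirrors how config.py wraps the tokenizer
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tokenizer.json")

phonemes = ["h", "ˈɛ", "l", "ˈoʊ"]          # stand-in for g2p output
ids = tokenizer.encode("".join(phonemes))   # token ids, <bos>/<eos> added by the template
vocab = tokenizer.get_vocab()               # replaces the old hardcoded get_phone_symmap()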

View File

@@ -1,51 +1,23 @@
-dataset:
-  training: []
-  validation: []
-  noise: []
-
-  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
-
-  use_hdf5: True
-  use_metadata: True
-  hdf5_flag: r
-  validate: True
-
-  workers: 2
-  cache: True
-
-  phones_range: [4, 512]
-  duration_range: [1.0, 32.0]
-
-  random_utterance: 1.0
-  max_prompts: 3
-  prompt_duration: 6.0
-
-  sample_type: speaker
-
-  tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]
-
 models:
-  _prom_levels: 8
-  _max_levels: 8
-
-  _models:
-    - name: "ar+nar"
-      size: "full"
-      resp_levels: 8
-      prom_levels: 8
-      tasks: 8
-      arch_type: "retnet"
-      training: True
-      version: 3
+  - name: "ar+nar"
+    size: "full"
+    resp_levels: 8
+    prom_levels: 8
+    tasks: 8
+    langs: 2
+    tones: 1
+    arch_type: "retnet"
+    training: True
+    version: 3
 
 hyperparameters:
-  batch_size: 8
-  gradient_accumulation_steps: 32
-  gradient_clipping: 100
+  batch_size: 4
+  gradient_accumulation_steps: 4
+  gradient_clipping: 10
 
-  optimizer: Prodigy
+  optimizer: Adagrad
   torch_optimizer: True
-  learning_rate: 0.0625
+  learning_rate: 1.0e-2
 
   scheduler_type: ""
   #scheduler_type: OneCycle
@@ -67,22 +39,24 @@ hyperparameters:
   # decay_mom_rate: 0.0
 
 evaluation:
-  batch_size: 16
-  frequency: 250
-  size: 16
+  batch_size: 8
+  frequency: 10000
+  size: 8
 
-  steps: 450
+  steps: 500
   ar_temperature: 0.95
   nar_temperature: 0.25
   load_disabled_engines: True
 
 trainer:
+  no_logger: True
+
   iterations: 1_000_000
 
   save_tag: step
   save_on_oom: True
   save_on_quit: True
-  save_frequency: 100
+  save_frequency: 250
   export_on_save: True
 
   keep_last_checkpoints: 4
@@ -91,33 +65,82 @@ trainer:
   load_disabled_engines: False
 
   #load_state_dict: True
-  #strict_loading: False
+  strict_loading: False
   #load_tag: "9500"
   #load_states: False
   #restart_step_count: True
 
   gc_mode: None # "global_step"
 
-  weight_dtype: bfloat16
+  weight_dtype: float32
   amp: False
 
   backend: deepspeed
   deepspeed:
+    inferencing: True
     zero_optimization_level: 0
-    use_compression_training: True
+    use_compression_training: False
 
   activation_checkpointing: True
 
+  load_webui: True
+
 inference:
-  use_vocos: True
+  backend: deepspeed
+  audio_backend: "dac"
   normalize: False
 
-  weight_dtype: bfloat16
+  weight_dtype: float32
   amp: False
 
 bitsandbytes:
   enabled: False
-  injects: True
-  linear: True
-  embedding: True
+
+  injects: False
+  replace: False
+
+  linear: False
+  embedding: False
+
+  bitnet: False
+
+fp8:
+  enabled: False
+  backend: "te"
+
+experimental: True
+
+dataset:
+  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
+  speaker_group_getter: "lambda p: f'{p.parts[-3]}'"
+
+  speaker_languages:
+    ja: []
+
+  use_hdf5: True
+  use_metadata: True
+  hdf5_flag: r
+  validate: True
+
+  workers: 8
+  cache: True
+
+  #phones_range: [4, 512]
+  #duration_range: [1.0, 32.0]
+  phones_range: [0, 512]
+  duration_range: [0.0, 64.0]
+
+  random_utterance: 1.0
+  max_prompts: 3
+  prompt_duration: 6.0
+
+  max_resps: 1
+  p_resp_append: 0.25
+
+  sample_type: speaker
+
+  tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]
+
+  training: []
+  validation: []
+  noise: []

Binary file not shown.

View File

@@ -1,8 +1,8 @@
 import os
 import json
 
-input_dataset = "small+medium"
-output_dataset = "LibriLight-6K"
+input_dataset = "duplicate"
+output_dataset = "LibriLight-4K"
 
 for speaker_id in os.listdir(f'./{input_dataset}/'):
 	if not os.path.isdir(f'./{input_dataset}/{speaker_id}/'):

View File

@@ -8,9 +8,14 @@ from pathlib import Path
 from vall_e.emb.g2p import encode as valle_phonemize
 from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
 
-input_audio = "voices"
+input_audio = "voice"
 input_metadata = "metadata"
-output_dataset = "training"
+output_dataset = "training-24K"
+
+missing = {
+	"transcription": [],
+	"audio": []
+}
 
 device = "cuda"
@@ -31,13 +36,15 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 		metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
 		if not metadata_path.exists():
-			print("Does not exist:", metadata_path)
+			#print("Does not exist:", metadata_path)
+			missing["transcription"].append(str(metadata_path))
 			continue
 
 		try:
 			metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
 		except Exception as e:
-			print("Failed to load metadata:", metadata_path, e)
+			#print("Failed to load metadata:", metadata_path, e)
+			missing["transcription"].append(str(metadata_path))
 			continue
 
 		txts = []
@@ -46,7 +53,8 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 		for filename in metadata.keys():
 			inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
 			if not inpath.exists():
-				print("Does not exist:", inpath)
+				#print("Does not exist:", inpath)
+				missing["audio"].append(str(inpath))
 				continue
 
 			extension = os.path.splitext(filename)[-1][1:]
@@ -117,21 +125,26 @@ for dataset_name in os.listdir(f'./{input_audio}/'):
 					waveform[:, start:end],
 					sample_rate
 				))
 
-		for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"):
-			outpath, text, language = job
-			phones = valle_phonemize(text)
-			data = {
-				"text": text.strip(),
-				"phonemes": phones,
-				"language": language,
-			}
-			open(_replace_file_extension(outpath, ".json"), 'w', encoding='utf-8').write(json.dumps(data))
+		if len(txts) > 0:
+			for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"):
+				outpath, text, language = job
+				phones = valle_phonemize(text)
+				data = {
+					"text": text.strip(),
+					"phonemes": phones,
+					"language": language,
+				}
+				open(_replace_file_extension(outpath, ".json"), 'w', encoding='utf-8').write(json.dumps(data))
 
-		for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
-			try:
-				outpath, waveform, sample_rate = job
-				qnt = valle_quantize(waveform, sr=sample_rate, device=device)
-				qnt.save(_replace_file_extension(outpath, ".dac"))
-			except Exception as e:
-				print(f"Failed to quantize: {outpath}:", e)
-				continue
+		if len(wavs) > 0:
+			for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
+				try:
+					outpath, waveform, sample_rate = job
+					qnt = valle_quantize(waveform, sr=sample_rate, device=device)
+					qnt.save(_replace_file_extension(outpath, ".dac"))
+				except Exception as e:
+					print(f"Failed to quantize: {outpath}:", e)
+					continue
+
+open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))

View File

@@ -0,0 +1,57 @@
+import os
+import json
+import torch
+import torchaudio
+
+from tqdm.auto import tqdm
+from pathlib import Path
+
+from tokenizers import Tokenizer
+from tokenizers.models import BPE, Unigram, WordLevel, WordPiece
+from tokenizers.trainers import BpeTrainer
+from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.processors import TemplateProcessing
+
+input_metadata = "training-24K"
+
+output_file = Path("./dataset.json")
+tokenizer_data = []
+
+def pad(num, zeroes):
+	return str(num).zfill(zeroes+1)
+
+if output_file.exists():
+	tokenizer_data = json.loads(open(str(output_file), "r", encoding="utf-8").read())
+else:
+	for dataset_name in os.listdir(f'./{input_metadata}/'):
+		if not os.path.isdir(f'./{input_metadata}/{dataset_name}/'):
+			continue
+		for speaker_id in tqdm(os.listdir(f'./{input_metadata}/{dataset_name}/'), desc="Processing speaker"):
+			if not os.path.isdir(f'./{input_metadata}/{dataset_name}/{speaker_id}'):
+				continue
+			for id in os.listdir(f'./{input_metadata}/{dataset_name}/{speaker_id}/'):
+				if ".json" not in id:
+					continue
+
+				metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/{id}')
+				metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
+
+				tokenizer_data.append( f'{"".join(metadata["phonemes"])}' )
+
+	open(output_file, 'w', encoding='utf-8').write(json.dumps(tokenizer_data))
+
+unk_token = "<unk>"
+spl_tokens = ["<bos>", "</eos>", unk_token, "<mask>"]
+
+trainer = BpeTrainer(special_tokens = spl_tokens, vocab_size = 256)
+tokenizer = Tokenizer(BPE(unk_token = unk_token))
+tokenizer.pre_tokenizer = Whitespace()
+tokenizer.post_processor = TemplateProcessing(
+	single="<bos> $A <eos>",
+	special_tokens=[("<bos>", 1), ("<eos>", 2)],
+)
+
+tokenizer.train_from_iterator(tokenizer_data, trainer=trainer)
+
+tokenizer.save("./tokenizer.json")
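
For reference, a quick sanity check of the trained artifact (a sketch; the phoneme string is an assumption, not repo output):

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("./tokenizer.json")
encoded = tokenizer.encode("dʒˈʌst ɐ tˈɛst")
print(encoded.tokens)  # BPE merges over the raw phoneme string
print(encoded.ids)     # wrapped with the <bos>/<eos> ids declared in the template

Note in passing that spl_tokens registers "</eos>" while the TemplateProcessing references "<eos>" with a hand-declared id; if the ids ever look off, that mismatch is the first place to check.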

View File

@@ -18,6 +18,9 @@ from omegaconf import OmegaConf
 from .utils.distributed import world_size
 
+# Yuck
+from transformers import PreTrainedTokenizerFast
+
 @dataclass()
 class _Config:
 	cfg_path: str | None = None
@@ -540,10 +543,12 @@ class Config(_Config):
 	inference: Inference = field(default_factory=lambda: Inference)
 	bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
 
+	tokenizer: str = "./tokenizer.json"
+
 	fp8: FP8 = field(default_factory=lambda: FP8)
 
 	sample_rate: int = 24_000
-	variable_sample_rate: bool = False
+	variable_sample_rate: bool = True
 
 	@property
 	def distributed(self):
@@ -611,16 +616,19 @@ cfg = Config.from_cli()
 # OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
 try:
 	cfg.format()
+	# cached_property stopped working...
 	if cfg.dataset.use_hdf5:
 		cfg.load_hdf5()
 except Exception as e:
-	print(e)
+	print("Error while parsing config YAML:", e)
 	pass
 
+try:
+	from transformers import PreTrainedTokenizerFast
+	cfg.tokenizer = (cfg.relpath if cfg.cfg_path is not None else Path("./data/")) / cfg.tokenizer
+	cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer))
+except Exception as e:
+	print("Error while parsing tokenizer:", e)
+	pass
+
 if __name__ == "__main__":
 	print(cfg)

View File

@@ -24,17 +24,17 @@ from torch import Tensor
 from torch.utils.data import DataLoader, Dataset as _Dataset
 from torch.utils.data.distributed import DistributedSampler
 from tqdm.auto import tqdm
 
 # torch.multiprocessing.set_sharing_strategy("file_system")
 
 _logger = logging.getLogger(__name__)
 
 # to-do: clean up this symmap mess
 def get_phone_symmap():
-	if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
-		return json.loads( cfg.hdf5['symmap'].asstr()[()] )
-	return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, '': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, '': 220, 'eˈ': 221, 'ʍ': 222, '': 223, '': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
+	return cfg.tokenizer.get_vocab()
+
+def tokenize( phones ):
+	return tokenizer.encode( "".join(phones) )
+	#return [*map(get_phone_symmap.get, _get_phones(path))]
 
 def get_lang_symmap():
 	return {
@@ -178,7 +178,9 @@ def _get_phones(path, language="en"):
 	else:
 		content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
 	content = _cleanup_phones( content )
 
-	return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
+	return "".join(content)
+	#return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
 
 def _interleaved_reorder(l, fn):
 	groups = defaultdict(list)
@@ -435,7 +437,7 @@ class Dataset(_Dataset):
 			text = torch.from_numpy(text).to(self.text_dtype)
 			resps = torch.from_numpy(resps).to(torch.int16)
 		else:
-			text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
+			text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
 			resps = _load_quants(path)
 
 		lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
@@ -847,18 +849,21 @@ def create_dataset_hdf5( skip_existing=True ):
 				# audio
 				if audios:
 					qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
-					codes = torch.from_numpy(qnt["codes"].astype(int))[0].t()
+					codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(dtype=torch.int16)
 
 					if _get_quant_extension() == ".dac":
 						if "audio" in group:
 							del group["audio"]
 
 						duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
-						metadata[id]["metadata"] = qnt["metadata"]
+						metadata[id]["metadata"] = {
+							"original_length": qnt["metadata"]["original_length"],
+							"sample_rate": qnt["metadata"]["sample_rate"],
+						}
 					else:
 						qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
 						duration = qnt.shape[0] / 75
 
-					group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
+					group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')
 
 					group.attrs['duration'] = duration
 					metadata[id]["duration"] = duration
@@ -869,17 +874,22 @@ def create_dataset_hdf5( skip_existing=True ):
 				# text
 				if texts:
 					if _get_quant_extension() == ".json":
-						j_son = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
-						content = j_son["phonemes"]
+						json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+						content = json_metadata["phonemes"]
 					else:
 						content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
 
+					"""
 					phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
 					for s in set(phones):
 						if s not in symmap:
 							symmap[s] = len(symmap.keys())
 
 					phn = [ symmap[s] for s in phones ]
+					"""
+
+					phn = cfg.tokenizer.encode("".join(content))
+					phn = np.array(phn).astype(np.uint8)
 
 					if "text" in group:
 						del group["text"]
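
One thing worth noting about the new HDF5 text path: phn is stored as np.uint8, which only holds ids 0-255 and so depends on the tokenizer above being trained with vocab_size = 256. A standalone sanity check (a sketch, not repo code; the phoneme string is made up):

import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("./tokenizer.json")  # path assumed from the training script
ids = tokenizer.encode("ˈoʊnli ɐ tˈɛst").ids
assert max(ids) < 256, "token id would overflow uint8 storage"
phn = np.array(ids).astype(np.uint8)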

View File

@@ -91,15 +91,8 @@ class TTS():
 			return text
 
 		content = g2p.encode(text, language=language)
-		content = _cleanup_phones( content )
 
-		# ick
-		try:
-			phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
-			return torch.tensor([*map(self.symmap.get, phones)])
-		except Exception as e:
-			pass
-
-		phones = [ " " if not p else p for p in content ]
-		return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
+		return torch.tensor(cfg.tokenizer.encode( "".join( content ) ))
 
 	def encode_lang( self, language ):
 		symmap = get_lang_symmap()