diff --git a/scripts/cleanup_dataset.py b/scripts/cleanup_dataset.py
index 04eca87..0915888 100644
--- a/scripts/cleanup_dataset.py
+++ b/scripts/cleanup_dataset.py
@@ -1,3 +1,7 @@
+"""
+# Helper script to clean up transcription metadata, whatever that entailed.
+"""
+
 import os
 import json
 import torch
diff --git a/scripts/deduplicate_librilight_libritts.py b/scripts/deduplicate_librilight_libritts.py
index bb6e483..be279f4 100755
--- a/scripts/deduplicate_librilight_libritts.py
+++ b/scripts/deduplicate_librilight_libritts.py
@@ -1,3 +1,7 @@
+"""
+# Helper script to try to detect any duplications between LibriLight and LibriTTS (I don't think there were any)
+"""
+
 import os
 import json
diff --git a/scripts/parse_ppp.py b/scripts/parse_ppp.py
index 51f2300..fbdcb33 100644
--- a/scripts/parse_ppp.py
+++ b/scripts/parse_ppp.py
@@ -1,3 +1,7 @@
+"""
+# Helper script to parse the PPP dataset into a friendlier hierarchy
+"""
+
 import os
 import json
 import torch
@@ -7,8 +11,6 @@ from pathlib import Path
 from vall_e.emb.g2p import encode as valle_phonemize
 from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
 
-device = "cuda"
-
 target = "in"
 
 audio_map = {}
@@ -86,6 +88,10 @@ for key, entry in audio_map.items():
 for name in data.keys():
 	open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
 
+# to-do: update to "The Proper Way"
+# for now it can just be fed back into "The Proper Way"
+"""
+device = "cuda"
 for key, text in tqdm(txts.items(), desc="Phonemizing..."):
 	path = Path(key)
 	phones = valle_phonemize(text)
@@ -93,4 +99,5 @@ for key, text in tqdm(txts.items(), desc="Phonemizing..."):
 
 for path in tqdm(wavs, desc="Quantizing..."):
 	qnt = valle_quantize(path, device=device)
-	torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
\ No newline at end of file
+	torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
+"""
\ No newline at end of file
diff --git a/scripts/prepare_librilight.py b/scripts/prepare_librilight.py
index c9ca16d..f82b8ef 100755
--- a/scripts/prepare_librilight.py
+++ b/scripts/prepare_librilight.py
@@ -1,32 +1,41 @@
+"""
+# Handles processing `facebookresearch/libri-light`'s unlabeled audio into a friendlier hierarchy
+"""
+
 import os
 import json
 
-input_dataset = "duplicate"
+datasets = ["small", "medium", "large", "duplicate"]
 output_dataset = "LibriLight-4K"
 
-for speaker_id in os.listdir(f'./{input_dataset}/'):
-	if not os.path.isdir(f'./{input_dataset}/{speaker_id}/'):
+for input_dataset in datasets:
+	if not os.path.isdir(f'./{input_dataset}/'):
 		continue
-	for book_name in os.listdir(f'./{input_dataset}/{speaker_id}/'):
-		subid = 0
-		for filename in os.listdir(f'./{input_dataset}/{speaker_id}/{book_name}'):
-			if filename[-5:] != ".json":
-				continue
+	for speaker_id in os.listdir(f'./{input_dataset}/'):
+		if not os.path.isdir(f'./{input_dataset}/{speaker_id}/'):
+			continue
+
+		for book_name in os.listdir(f'./{input_dataset}/{speaker_id}/'):
+			subid = 0
 
-			basename = filename[:-5]
+			for filename in os.listdir(f'./{input_dataset}/{speaker_id}/{book_name}'):
+				if filename[-5:] != ".json":
+					continue
 
-			json_path = f'./{input_dataset}/{speaker_id}/{book_name}/{basename}.json'
-			flac_path = f'./{input_dataset}/{speaker_id}/{book_name}/{basename}.flac'
+				basename = filename[:-5]
 
-			j = json.load(open(json_path, 'r', encoding="utf-8"))
-			id = j['book_meta']['id']
-
-			json_id_path = f'./{output_dataset}/{speaker_id}/{speaker_id}_{id}_{subid}.json'
-			flac_id_path = f'./{output_dataset}/{speaker_id}/{speaker_id}_{id}_{subid}.flac'
+				json_path = f'./{input_dataset}/{speaker_id}/{book_name}/{basename}.json'
+				flac_path = f'./{input_dataset}/{speaker_id}/{book_name}/{basename}.flac'
 
-			os.makedirs(f'./{output_dataset}/{speaker_id}/', exist_ok=True)
-			os.rename(json_path, json_id_path)
-			os.rename(flac_path, flac_id_path)
+				j = json.load(open(json_path, 'r', encoding="utf-8"))
+				id = j['book_meta']['id']
+
+				json_id_path = f'./{output_dataset}/{speaker_id}/{speaker_id}_{id}_{subid}.json'
+				flac_id_path = f'./{output_dataset}/{speaker_id}/{speaker_id}_{id}_{subid}.flac'
 
-			subid += 1
+				os.makedirs(f'./{output_dataset}/{speaker_id}/', exist_ok=True)
+				os.rename(json_path, json_id_path)
+				os.rename(flac_path, flac_id_path)
+
+				subid += 1
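
For clarity, here is a dry-run sketch of the rename scheme prepare_librilight.py applies above: flattening ./{small,medium,large,duplicate}/{speaker}/{book}/ into ./LibriLight-4K/{speaker}/{speaker}_{book_meta_id}_{subid}.{json,flac}. The planned_moves helper is hypothetical, for previewing moves without touching the filesystem; the paths and the book_meta key mirror the script itself.

import json
from pathlib import Path

def planned_moves(input_dataset: str, output_dataset: str = "LibriLight-4K"):
    """Yield (src, dst) pairs without renaming anything."""
    for speaker in sorted(Path(input_dataset).iterdir()):
        if not speaker.is_dir():
            continue
        for book in sorted(p for p in speaker.iterdir() if p.is_dir()):
            subid = 0
            for meta in sorted(book.glob("*.json")):
                # output stem is {speaker}_{book_meta_id}_{subid}, per the script above
                book_id = json.load(open(meta, encoding="utf-8"))["book_meta"]["id"]
                stem = f"{speaker.name}_{book_id}_{subid}"
                yield meta, Path(output_dataset) / speaker.name / f"{stem}.json"
                yield meta.with_suffix(".flac"), Path(output_dataset) / speaker.name / f"{stem}.flac"
                subid += 1

# e.g. for src, dst in planned_moves("duplicate"): print(src, "->", dst)
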
diff --git a/scripts/prepare_libritts.py b/scripts/prepare_libritts.py
deleted file mode 100755
index d15c24b..0000000
--- a/scripts/prepare_libritts.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import os
-import json
-
-input_dataset = "LibriTTS_R"
-output_dataset = "LibriTTS-Train"
-
-for dataset_name in os.listdir(f'./{input_dataset}/'):
-	if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
-		continue
-	for speaker_id in os.listdir(f'./{input_dataset}/{dataset_name}/'):
-		if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
-			continue
-		for book_id in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
-			if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
-				continue
-			for filename in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
-				if filename[-4:] != ".wav":
-					continue
-
-				os.makedirs(f'./{output_dataset}/{speaker_id}/', exist_ok=True)
-				os.rename(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}', f'./{output_dataset}/{speaker_id}/{filename}')
\ No newline at end of file
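
The deleted prepare_libritts.py is subsumed by the rewritten process_libritts.py below, whose process_items helper shards speakers across concurrent runs: stride is the number of parallel invocations, stride_offset selects which shard this invocation handles. A minimal demonstration (the function body is copied from the script below; the speaker list is made up):

def process_items(items, stride=0, stride_offset=0):
    # stride == 0 means "process everything"; otherwise keep every stride-th item
    items = sorted(items)
    return items if stride == 0 else [
        item for i, item in enumerate(items) if (i + stride_offset) % stride == 0
    ]

speakers = ["100", "101", "102", "103", "104"]
assert process_items(speakers, stride=2, stride_offset=0) == ["100", "102", "104"]
assert process_items(speakers, stride=2, stride_offset=1) == ["101", "103"]
# e.g. two shells: `--stride 2 --stride-offset 0` and `--stride 2 --stride-offset 1`
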
diff --git a/scripts/process_libritts.py b/scripts/process_libritts.py
index b3b1283..5d8b06c 100755
--- a/scripts/process_libritts.py
+++ b/scripts/process_libritts.py
@@ -1,6 +1,13 @@
+"""
+# Handles processing audio provided through --input-audio, with adequately annotated transcriptions provided through --input-metadata (generated through transcribe.py)
+# Outputs NumPy objects containing quantized audio and adequate metadata for loading in the trainer through --output-dataset
+"""
+
 import os
 import json
+import argparse
 import torch
+import torchaudio
 import numpy as np
 
 from tqdm.auto import tqdm
@@ -8,104 +15,215 @@ from pathlib import Path
 from vall_e.config import cfg
 
-# things that could be args
-cfg.sample_rate = 24_000
-cfg.audio_backend = "encodec"
-"""
-cfg.inference.weight_dtype = "bfloat16"
-cfg.inference.dtype = torch.bfloat16
-cfg.inference.amp = True
-"""
+def pad(num, zeroes):
+	return str(num).zfill(zeroes+1)
 
-from vall_e.emb.g2p import encode as valle_phonemize
-from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
+def process_items( items, stride=0, stride_offset=0 ):
+	items = sorted( items )
+	return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
 
-audio_extension = ".enc"
-if cfg.audio_backend == "dac":
-	audio_extension = ".dac"
-elif cfg.audio_backend == "audiodec":
-	audio_extension = ".dec"
+def process(
+	audio_backend="encodec",
+	input_audio="LibriTTS_R",
+	output_dataset="training",
+	raise_exceptions=False,
+	stride=0,
+	stride_offset=0,
+	slice="auto",
 
-input_dataset = "LibriTTS_R"
-output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
-device = "cuda"
+	device="cuda",
+	dtype="float16",
+	amp=False,
+	):
+	# encodec / vocos
 
-txts = []
-wavs = []
+	if audio_backend in ["encodec", "vocos"]:
+		audio_extension = ".enc"
+		cfg.sample_rate = 24_000
+		cfg.model.resp_levels = 8
+	elif audio_backend == "dac":
+		audio_extension = ".dac"
+		cfg.sample_rate = 44_100
+		cfg.model.resp_levels = 9
+	elif audio_backend == "audiodec":
+		cfg.sample_rate = 48_000
+		audio_extension = ".dec"
+		cfg.model.resp_levels = 8 # ?
+	else:
+		raise Exception(f"Unknown audio backend: {audio_backend}")
 
-for dataset_name in os.listdir(f'./{input_dataset}/'):
-	if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
-		continue
+	# prepare from args
+	cfg.audio_backend = audio_backend # "encodec"
+	cfg.inference.weight_dtype = dtype # "bfloat16"
+	cfg.inference.amp = amp # False
 
-	for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc="Processing speaker"):
-		if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
+	# import after because we've overridden the config above
+	# need to validate if this is even necessary anymore
+	from vall_e.emb.g2p import encode as phonemize
+	from vall_e.emb.qnt import encode as quantize, _replace_file_extension
+
+	output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
+
+	language_map = {} # k = group, v = language
+
+	ignore_groups = [] # skip these groups
+	ignore_speakers = [] # skip these speakers
+
+	only_groups = [] # only process these groups
+	only_speakers = [] # only process these speakers
+
+	always_slice_groups = [] # always slice from this group
+
+	missing = {
+		"transcription": [],
+		"audio": []
+	}
+	dataset = []
+
+	# Layout: ./LibriTTS_R/train-clean-100/103/1241
+	for group_name in sorted(os.listdir(f'./{input_audio}/')):
+		if not os.path.isdir(f'./{input_audio}/{group_name}/'):
+			print("Is not dir:", f'./{input_audio}/{group_name}/')
 			continue
-
-		os.makedirs(f'./{output_dataset}/{speaker_id}/', exist_ok=True)
-		for book_id in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
-			if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
-				continue
-			for filename in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
-				# os.rename(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}', f'./{output_dataset}/{speaker_id}/{filename}')
-				inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}')
-				outpath = Path(f'./{output_dataset}/{speaker_id}/{filename}')
+
+		if group_name in ignore_groups:
+			continue
+		if only_groups and group_name not in only_groups:
+			continue
+
+		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"):
+			if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
+				print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
+				continue
+
+			if speaker_id in ignore_speakers:
+				continue
+			if only_speakers and speaker_id not in only_speakers:
+				continue
+
+			os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
+
+			if f'{group_name}/{speaker_id}' not in dataset:
+				dataset.append(f'{group_name}/{speaker_id}')
+
+			txts = []
+			wavs = []
+
+			for book_id in os.listdir(f'./{input_audio}/{group_name}/{speaker_id}'):
+				if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}'):
+					print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}/{book_id}')
+					continue
+
+				for filename in os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}'):
+					inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}/{filename}')
+					if not inpath.exists():
+						missing["audio"].append(str(inpath))
 
-				if ".wav" in filename: # and not _replace_file_extension(outpath, ".dac").exists():
-					txts.append((
-						inpath,
-						outpath
+					extension = os.path.splitext(filename)[-1][1:]
+					fname = filename.replace(f'.{extension}', "")
+
+					waveform, sample_rate = None, None
+					language = "en"
+
+					outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
+					text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
+
+					if len(text) == 0:
+						continue
+
+					if _replace_file_extension(outpath, audio_extension).exists():
+						continue
+
+					if waveform is None:
+						waveform, sample_rate = torchaudio.load(inpath)
+						if waveform.shape[0] > 1:
+							waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+					wavs.append((
+						outpath,
+						text,
+						language,
+						waveform,
+						sample_rate
 					))
 
-for paths in tqdm(txts, desc="Processing..."):
-	inpath, outpath = paths
-	try:
-		if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
-			data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
-			qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
-
-			if not isinstance(data["phonemes"], str):
-				data["phonemes"] = "".join(data["phonemes"])
+			if len(wavs) > 0:
+				for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
+					try:
+						outpath, text, language, waveform, sample_rate = job
 
-			for k, v in data.items():
-				qnt[()]['metadata'][k] = v
+						phones = phonemize(text, language=language)
+						qnt = quantize(waveform, sr=sample_rate, device=device)
 
-			np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
-		else:
-			text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
-
-			phones = valle_phonemize(text)
-			qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
-			if cfg.audio_backend == "dac":
-				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-					"codes": qnt.codes.cpu().numpy().astype(np.uint16),
-					"metadata": {
-						"original_length": qnt.original_length,
-						"sample_rate": qnt.sample_rate,
-
-						"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-						"chunk_length": qnt.chunk_length,
-						"channels": qnt.channels,
-						"padding": qnt.padding,
-						"dac_version": "1.0.0",
+						if cfg.audio_backend == "dac":
+							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+								"codes": qnt.codes.cpu().numpy().astype(np.uint16),
+								"metadata": {
+									"original_length": qnt.original_length,
+									"sample_rate": qnt.sample_rate,
+
+									"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+									"chunk_length": qnt.chunk_length,
+									"channels": qnt.channels,
+									"padding": qnt.padding,
+									"dac_version": "1.0.0",
 
-						"text": text.strip(),
-						"phonemes": "".join(phones),
-						"language": "en",
-					},
-				})
-			else:
-				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-					"codes": qnt.cpu().numpy().astype(np.uint16),
-					"metadata": {
-						"original_length": qnt.shape[-1] / 75.0,
-						"sample_rate": cfg.sample_rate,
+									"text": text.strip(),
+									"phonemes": "".join(phones),
+									"language": language,
+								},
+							})
+						else:
+							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+								"codes": qnt.cpu().numpy().astype(np.uint16),
+								"metadata": {
+									"original_length": waveform.shape[-1],
+									"sample_rate": sample_rate,
 
-						"text": text.strip(),
-						"phonemes": "".join(phones),
-						"language": "en",
-					},
-				})
-	except Exception as e:
-		tqdm.write(f"Failed to process: {paths}: {e}")
+									"text": text.strip(),
+									"phonemes": "".join(phones),
+									"language": language,
+								},
+							})
+					except Exception as e:
+						print(f"Failed to quantize: {outpath}:", e)
+						if raise_exceptions:
+							raise e
+						continue
+
+	open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
+	open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
+
+def main():
+	parser = argparse.ArgumentParser()
+
+	parser.add_argument("--audio-backend", type=str, default="encodec")
+	parser.add_argument("--dtype", type=str, default="bfloat16")
+	parser.add_argument("--amp", action="store_true")
+	parser.add_argument("--input-audio", type=str, default="LibriTTS_R")
+	parser.add_argument("--output-dataset", type=str, default="training/dataset")
+	parser.add_argument("--device", type=str, default="cuda")
+	parser.add_argument("--raise-exceptions", action="store_true")
+	parser.add_argument("--stride", type=int, default=0)
+	parser.add_argument("--stride-offset", type=int, default=0)
+	parser.add_argument("--slice", type=str, default="auto")
+
+	args = parser.parse_args()
+
+	process(
+		audio_backend=args.audio_backend,
+		input_audio=args.input_audio,
+		output_dataset=args.output_dataset,
+		raise_exceptions=args.raise_exceptions,
+		stride=args.stride,
+		stride_offset=args.stride_offset,
+		slice=args.slice,
+
+		device=args.device,
+		dtype=args.dtype,
+		amp=args.amp,
+	)
+
+if __name__ == "__main__":
+	main()
\ No newline at end of file
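
Since the script writes each utterance as a pickled dict through np.save, reading one back needs allow_pickle=True plus a [()] unwrap of the zero-dimensional object array, the same access pattern the removed in-place-update branch used with qnt[()]['metadata']. A minimal sketch, with a hypothetical output path:

import numpy as np

# hypothetical path under "./training/dataset/24KHz-encodec/", for illustration
payload = np.load("training/dataset/24KHz-encodec/train-clean-100/103/103_1241_000000.enc",
                  allow_pickle=True)[()]  # unwrap the 0-d object array into the dict
codes = payload["codes"]      # uint16 codebook indices from the quantizer
meta = payload["metadata"]    # text, phonemes, language, sample_rate, ...
print(codes.shape, meta["text"][:40])
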
diff --git a/scripts/train_tokenizer.py b/scripts/train_tokenizer.py
index 57c0f95..b2f6c9b 100644
--- a/scripts/train_tokenizer.py
+++ b/scripts/train_tokenizer.py
@@ -1,3 +1,7 @@
+"""
+# Helper script to grab all phonemes from parsed dataset metadata to find the "best" tokenizer dict
+"""
+
 import os
 import json
 import torch
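
For context on what train_tokenizer.py consumes: the phonemes live as joined strings inside the metadata written by process_libritts.py above, so a first-pass survey of the symbol inventory can be as simple as the sketch below. The per-character tally is an illustration only, not the script's actual algorithm, and the root path is hypothetical.

from collections import Counter
from pathlib import Path
import numpy as np

counts = Counter()
# walk every quantized artifact and tally the phoneme symbols in its metadata
for path in Path("training/dataset/24KHz-encodec").rglob("*.enc"):
    meta = np.load(path, allow_pickle=True)[()]["metadata"]
    counts.update(meta["phonemes"])

print(counts.most_common(32))
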
"".join(phones), + "language": language, + }, + }) + else: + np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { + "codes": qnt.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": waveform.shape[-1], + "sample_rate": sample_rate, - "text": text.strip(), - "phonemes": "".join(phones), - "language": "en", - }, - }) - except Exception as e: - tqdm.write(f"Failed to process: {paths}: {e}") + "text": text.strip(), + "phonemes": "".join(phones), + "language": language, + }, + }) + except Exception as e: + print(f"Failed to quantize: {outpath}:", e) + if raise_exceptions: + raise e + continue + + open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) + open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument("--audio-backend", type=str, default="encodec") + parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--amp", action="store_true") + parser.add_argument("--input-audio", type=str, default="LibriTTS_R") + parser.add_argument("--output-dataset", type=str, default="training/dataset") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--raise-exceptions", action="store_true") + parser.add_argument("--stride", type=int, default=0) + parser.add_argument("--stride-offset", type=int, default=0) + parser.add_argument("--slice", type=str, default="auto") + + args = parser.parse_args() + + process( + audio_backend=args.audio_backend, + input_audio=args.input_audio, + output_dataset=args.output_dataset, + raise_exceptions=args.raise_exceptions, + stride=args.stride, + stride_offset=args.stride_offset, + slice=args.slice, + + device=args.device, + dtype=args.dtype, + amp=args.amp, + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/train_tokenizer.py b/scripts/train_tokenizer.py index 57c0f95..b2f6c9b 100644 --- a/scripts/train_tokenizer.py +++ b/scripts/train_tokenizer.py @@ -1,3 +1,7 @@ +""" +# Helper script to grab all phonemes through parsed dataset metadata to find the "best" tokenizer dict +""" + import os import json import torch diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 2e4f244..3e7b767 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -59,6 +59,7 @@ def process( cfg.inference.amp = amp # False # import after because we've overriden the config above + # need to validate if this is even necessary anymore from .g2p import encode as phonemize from .qnt import encode as quantize, _replace_file_extension @@ -275,8 +276,8 @@ def process( raise e continue - open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) - open("./dataset_list.json", 'w', encoding='utf-8').write(json.dumps(dataset)) + open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) + open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) def main(): parser = argparse.ArgumentParser()