re-adapted process_libritts.py to a 'better' way (better because it processed without needing to shuffle a bunch of things and adapt to cope or something)

This commit is contained in:
mrq 2024-08-05 20:34:58 -05:00
parent 3f73fcca29
commit 134dac8c2b
8 changed files with 257 additions and 131 deletions

View File

@ -1,3 +1,7 @@
"""
# Helper script to clean up transcription metadata, whatever that entailed.
"""
import os import os
import json import json
import torch import torch

View File

@ -1,3 +1,7 @@
"""
# Helper script to try and detect any duplications between LibriLight and LibriTTS (I don't think there were any)
"""
import os import os
import json import json

View File

@ -1,3 +1,7 @@
"""
# Helper script to parse PPP dataset into a friendlier hierarchy
"""
import os import os
import json import json
import torch import torch
@ -7,8 +11,6 @@ from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
device = "cuda"
target = "in" target = "in"
audio_map = {} audio_map = {}
@ -86,6 +88,10 @@ for key, entry in audio_map.items():
for name in data.keys(): for name in data.keys():
open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) ) open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
# to-do: update to "The Proper Way"
# for now it can just be fed back into "The Proper Way""
"""
device = "cuda"
for key, text in tqdm(txts.items(), desc="Phonemizing..."): for key, text in tqdm(txts.items(), desc="Phonemizing..."):
path = Path(key) path = Path(key)
phones = valle_phonemize(text) phones = valle_phonemize(text)
@ -94,3 +100,4 @@ for key, text in tqdm(txts.items(), desc="Phonemizing..."):
for path in tqdm(wavs, desc="Quantizing..."): for path in tqdm(wavs, desc="Quantizing..."):
qnt = valle_quantize(path, device=device) qnt = valle_quantize(path, device=device)
torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt")) torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
"""

View File

@ -1,15 +1,24 @@
"""
# Handles processing `facebookresearch/libri-light`'s unlabeled audio into a friendlier hierarchy
"""
import os import os
import json import json
input_dataset = "duplicate" datasets = ["small", "medium", "large", "duplicate"]
output_dataset = "LibriLight-4K" output_dataset = "LibriLight-4K"
for input_dataset in datasets:
if not os.path.isdir(f'./{input_dataset}/'):
continue
for speaker_id in os.listdir(f'./{input_dataset}/'): for speaker_id in os.listdir(f'./{input_dataset}/'):
if not os.path.isdir(f'./{input_dataset}/{speaker_id}/'): if not os.path.isdir(f'./{input_dataset}/{speaker_id}/'):
continue continue
for book_name in os.listdir(f'./{input_dataset}/{speaker_id}/'):
for book_name in os.listdir(f'./{input_dataset}/{speaker_id}/'):
subid = 0 subid = 0
for filename in os.listdir(f'./{input_dataset}/{speaker_id}/{book_name}'): for filename in os.listdir(f'./{input_dataset}/{speaker_id}/{book_name}'):
if filename[-5:] != ".json": if filename[-5:] != ".json":
continue continue

View File

@ -1,21 +0,0 @@
import os
import json
input_dataset = "LibriTTS_R"
output_dataset = "LibriTTS-Train"
for dataset_name in os.listdir(f'./{input_dataset}/'):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
continue
for speaker_id in os.listdir(f'./{input_dataset}/{dataset_name}/'):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
continue
for book_id in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
continue
for filename in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
if filename[-4:] != ".wav":
continue
os.makedirs(f'./{output_dataset}/{speaker_id}/', exist_ok=True)
os.rename(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}', f'./{output_dataset}/{speaker_id}/{filename}')

View File

@ -1,6 +1,13 @@
"""
# Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py)
# Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset
"""
import os import os
import json import json
import argparse
import torch import torch
import torchaudio
import numpy as np import numpy as np
from tqdm.auto import tqdm from tqdm.auto import tqdm
@ -8,74 +15,146 @@ from pathlib import Path
from vall_e.config import cfg from vall_e.config import cfg
# things that could be args def pad(num, zeroes):
cfg.sample_rate = 24_000 return str(num).zfill(zeroes+1)
cfg.audio_backend = "encodec"
"""
cfg.inference.weight_dtype = "bfloat16"
cfg.inference.dtype = torch.bfloat16
cfg.inference.amp = True
"""
from vall_e.emb.g2p import encode as valle_phonemize def process_items( items, stride=0, stride_offset=0 ):
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension items = sorted( items )
return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
def process(
audio_backend="encodec",
input_audio="LibriTTS_R",
output_dataset="training",
raise_exceptions=False,
stride=0,
stride_offset=0,
slice="auto",
device="cuda",
dtype="float16",
amp=False,
):
# encodec / vocos
if audio_backend in ["encodec", "vocos"]:
audio_extension = ".enc" audio_extension = ".enc"
if cfg.audio_backend == "dac": cfg.sample_rate = 24_000
cfg.model.resp_levels = 8
elif audio_backend == "dac":
audio_extension = ".dac" audio_extension = ".dac"
cfg.sample_rate = 44_100
cfg.model.resp_levels = 9
elif cfg.audio_backend == "audiodec": elif cfg.audio_backend == "audiodec":
sample_rate = 48_000
audio_extension = ".dec" audio_extension = ".dec"
cfg.model.resp_levels = 8 # ?
else:
raise Exception(f"Unknown audio backend: {audio_backend}")
input_dataset = "LibriTTS_R" # prepare from args
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" cfg.audio_backend = audio_backend # "encodec"
device = "cuda" cfg.inference.weight_dtype = dtype # "bfloat16"
cfg.inference.amp = amp # False
# import after because we've overriden the config above
# need to validate if this is even necessary anymore
from vall_e.emb.g2p import encode as phonemize
from vall_e.emb.qnt import encode as quantize, _replace_file_extension
output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
language_map = {} # k = group, v = language
ignore_groups = [] # skip these groups
ignore_speakers = [] # skip these speakers
only_groups = [] # only process these groups
only_speakers = [] # only process these speakers
always_slice_groups = [] # always slice from this group
missing = {
"transcription": [],
"audio": []
}
dataset = []
# Layout: ./LibriTTS_R/train-clean-100/103/1241
for group_name in sorted(os.listdir(f'./{input_audio}/')):
if not os.path.isdir(f'./{input_audio}/{group_name}/'):
print("Is not dir:", f'./{input_audio}/{group_name}/')
continue
if group_name in ignore_groups:
continue
if only_groups and group_name not in only_groups:
continue
for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"):
if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
continue
if speaker_id in ignore_speakers:
continue
if only_speakers and speaker_id not in only_speakers:
continue
os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
if f'{group_name}/{speaker_id}' not in dataset:
dataset.append(f'{group_name}/{speaker_id}')
txts = [] txts = []
wavs = [] wavs = []
for dataset_name in os.listdir(f'./{input_dataset}/'): for book_id in os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'): if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}'):
print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}/{book_id}')
continue continue
for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc="Processing speaker"): for filename in os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/{book_id}'):
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'): inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}/{filename}')
continue if not inpath.exists():
missing["audio"].append(str(inpath))
os.makedirs(f'./{output_dataset}/{speaker_id}/', exist_ok=True) extension = os.path.splitext(filename)[-1][1:]
for book_id in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'): fname = filename.replace(f'.{extension}', "")
if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
continue
for filename in os.listdir(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}'):
# os.rename(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}', f'./{output_dataset}/{speaker_id}/{filename}')
inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/{book_id}/{filename}') waveform, sample_rate = None, None
outpath = Path(f'./{output_dataset}/{speaker_id}/{filename}') language = "en"
if ".wav" in filename: # and not _replace_file_extension(outpath, ".dac").exists(): outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
txts.append((
inpath,
outpath
))
for paths in tqdm(txts, desc="Processing..."):
inpath, outpath = paths
try:
if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
if not isinstance(data["phonemes"], str):
data["phonemes"] = "".join(data["phonemes"])
for k, v in data.items():
qnt[()]['metadata'][k] = v
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
else:
text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read() text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
phones = valle_phonemize(text) if len(text) == 0:
qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device) continue
if _replace_file_extension(outpath, audio_extension).exists():
continue
if waveform is None:
waveform, sample_rate = torchaudio.load(inpath)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
wavs.append((
outpath,
text,
language,
waveform,
sample_rate
))
if len(wavs) > 0:
for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
try:
outpath, text, language, waveform, sample_rate = job
phones = phonemize(text, language=language)
qnt = quantize(waveform, sr=sample_rate, device=device)
if cfg.audio_backend == "dac": if cfg.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
@ -92,20 +171,59 @@ for paths in tqdm(txts, desc="Processing..."):
"text": text.strip(), "text": text.strip(),
"phonemes": "".join(phones), "phonemes": "".join(phones),
"language": "en", "language": language,
}, },
}) })
else: else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.cpu().numpy().astype(np.uint16), "codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": { "metadata": {
"original_length": qnt.shape[-1] / 75.0, "original_length": waveform.shape[-1],
"sample_rate": cfg.sample_rate, "sample_rate": sample_rate,
"text": text.strip(), "text": text.strip(),
"phonemes": "".join(phones), "phonemes": "".join(phones),
"language": "en", "language": language,
}, },
}) })
except Exception as e: except Exception as e:
tqdm.write(f"Failed to process: {paths}: {e}") print(f"Failed to quantize: {outpath}:", e)
if raise_exceptions:
raise e
continue
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--audio-backend", type=str, default="encodec")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--amp", action="store_true")
parser.add_argument("--input-audio", type=str, default="LibriTTS_R")
parser.add_argument("--output-dataset", type=str, default="training/dataset")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--raise-exceptions", action="store_true")
parser.add_argument("--stride", type=int, default=0)
parser.add_argument("--stride-offset", type=int, default=0)
parser.add_argument("--slice", type=str, default="auto")
args = parser.parse_args()
process(
audio_backend=args.audio_backend,
input_audio=args.input_audio,
output_dataset=args.output_dataset,
raise_exceptions=args.raise_exceptions,
stride=args.stride,
stride_offset=args.stride_offset,
slice=args.slice,
device=args.device,
dtype=args.dtype,
amp=args.amp,
)
if __name__ == "__main__":
main()

View File

@ -1,3 +1,7 @@
"""
# Helper script to grab all phonemes through parsed dataset metadata to find the "best" tokenizer dict
"""
import os import os
import json import json
import torch import torch

View File

@ -59,6 +59,7 @@ def process(
cfg.inference.amp = amp # False cfg.inference.amp = amp # False
# import after because we've overriden the config above # import after because we've overriden the config above
# need to validate if this is even necessary anymore
from .g2p import encode as phonemize from .g2p import encode as phonemize
from .qnt import encode as quantize, _replace_file_extension from .qnt import encode as quantize, _replace_file_extension
@ -275,8 +276,8 @@ def process(
raise e raise e
continue continue
open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
open("./dataset_list.json", 'w', encoding='utf-8').write(json.dumps(dataset)) open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()