diff --git a/scripts/process_emilia.py b/scripts/process_emilia.py new file mode 100644 index 0000000..ded3323 --- /dev/null +++ b/scripts/process_emilia.py @@ -0,0 +1,240 @@ +""" +# Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py) +# Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset +""" + +import os +import json +import argparse +import torch +import torchaudio +import numpy as np + +from tqdm.auto import tqdm +from pathlib import Path + +from vall_e.config import cfg + +def pad(num, zeroes): + return str(num).zfill(zeroes+1) + +def process_items( items, stride=0, stride_offset=0 ): + items = sorted( items ) + return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] + +def process( + audio_backend="encodec", + input_audio="Emilia", + output_dataset="training", + raise_exceptions=False, + stride=0, + stride_offset=0, + slice="auto", + + device="cuda", + dtype="float16", + amp=False, + ): + # encodec / vocos + + if audio_backend in ["encodec", "vocos"]: + audio_extension = ".enc" + cfg.sample_rate = 24_000 + cfg.model.resp_levels = 8 + elif audio_backend == "dac": + audio_extension = ".dac" + cfg.sample_rate = 44_100 + cfg.model.resp_levels = 9 + elif cfg.audio_backend == "audiodec": + sample_rate = 48_000 + audio_extension = ".dec" + cfg.model.resp_levels = 8 # ? + else: + raise Exception(f"Unknown audio backend: {audio_backend}") + + # prepare from args + cfg.audio_backend = audio_backend # "encodec" + cfg.inference.weight_dtype = dtype # "bfloat16" + cfg.inference.amp = amp # False + + # import after because we've overriden the config above + # need to validate if this is even necessary anymore + from vall_e.emb.g2p import encode as phonemize + from vall_e.emb.qnt import encode as quantize, _replace_file_extension + + output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" + + language_map = {} # k = group, v = language + + ignore_groups = [] # skip these groups + ignore_speakers = [] # skip these speakers + + only_groups = [] # only process these groups + only_speakers = [] # only process these speakers + + always_slice_groups = [] # always slice from this group + + missing = { + "transcription": [], + "audio": [] + } + dataset = [] + + # Layout: ./Emilia/JA/JA-B000000/JA_B00000_S00000_W000000.{json|mp3} + for language in sorted(os.listdir(f'./{input_audio}/')): + if not os.path.isdir(f'./{input_audio}/{language}/'): + print("Is not dir:", f'./{input_audio}/{language}/') + continue + + if language in ignore_groups: + continue + + if only_groups and language not in only_groups: + continue + + group_name = "Emilia" + + for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"): + if not os.path.isdir(f'./{input_audio}/{language}/{speaker_id}'): + print("Is not dir:", f'./{input_audio}/{language}/{speaker_id}') + continue + + if speaker_id in ignore_speakers: + continue + if only_speakers and speaker_id not in only_speakers: + continue + + os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True) + + if f'{group_name}/{speaker_id}' not in dataset: + dataset.append(f'{group_name}/{speaker_id}') + + txts = [] + wavs = [] + + for filename in os.listdir(f'./{input_audio}/{language}/{speaker_id}'): + if ".mp3" not in filename: + continue + + inpath = Path(f'./{input_audio}/{language}/{speaker_id}/{filename}') + jsonpath = _replace_file_extension(inpath, ".json") + if not inpath.exists() or not jsonpath.exists(): + missing["audio"].append(str(inpath)) + continue + + extension = os.path.splitext(filename)[-1][1:] + fname = filename.replace(f'.{extension}', "") + + waveform, sample_rate = None, None + + outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}') + metadata = json.load(open(jsonpath, "r", encoding="utf-8")) + + if "text" not in metadata: + continue + + if _replace_file_extension(outpath, audio_extension).exists(): + continue + + text = metadata["text"] + + if waveform is None: + waveform, sample_rate = torchaudio.load(inpath) + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + wavs.append(( + outpath, + text, + language.lower(), + waveform, + sample_rate + )) + + if len(wavs) > 0: + for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"): + try: + outpath, text, language, waveform, sample_rate = job + + phones = phonemize(text, language=language) + qnt = quantize(waveform, sr=sample_rate, device=device) + + + if cfg.audio_backend == "dac": + np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { + "codes": qnt.codes.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": qnt.original_length, + "sample_rate": qnt.sample_rate, + + "input_db": qnt.input_db.cpu().numpy().astype(np.float32), + "chunk_length": qnt.chunk_length, + "channels": qnt.channels, + "padding": qnt.padding, + "dac_version": "1.0.0", + + "text": text.strip(), + "phonemes": "".join(phones), + "language": language, + }, + }) + else: + np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { + "codes": qnt.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": waveform.shape[-1], + "sample_rate": sample_rate, + + "text": text.strip(), + "phonemes": "".join(phones), + "language": language, + }, + }) + except Exception as e: + print(f"Failed to quantize: {outpath}:", e) + if raise_exceptions: + raise e + continue + + open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) + open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument("--audio-backend", type=str, default="encodec") + parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--amp", action="store_true") + parser.add_argument("--input-audio", type=str, default="Emilia") + parser.add_argument("--output-dataset", type=str, default="training/dataset") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--raise-exceptions", action="store_true") + parser.add_argument("--stride", type=int, default=0) + parser.add_argument("--stride-offset", type=int, default=0) + parser.add_argument("--slice", type=str, default="auto") + + args = parser.parse_args() + + # do some assumption magic + # to-do: find a nice way to spawn multiple processes where tqdm plays nicely + if args.device.isnumeric(): + args.stride = torch.cuda.device_count() + args.stride_offset = int(args.device) + args.device = f'cuda:{args.device}' + + process( + audio_backend=args.audio_backend, + input_audio=args.input_audio, + output_dataset=args.output_dataset, + raise_exceptions=args.raise_exceptions, + stride=args.stride, + stride_offset=args.stride_offset, + slice=args.slice, + + device=args.device, + dtype=args.dtype, + amp=args.amp, + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/vall_e/data.py b/vall_e/data.py index 8d674e1..bc687b9 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -12,6 +12,7 @@ import itertools from .config import cfg from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt, pad_codes_with_silence +from .emb.g2p import encode as encode_phns from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler from .utils.distributed import global_rank, local_rank, world_size from .utils.io import torch_save, torch_load @@ -1316,6 +1317,32 @@ def create_train_val_dataloader(): return train_dl, subtrain_dl, val_dl +def process_artifact_metadata( artifact ): + metadata = {} + + if "text" in artifact["metadata"]: + metadata["text"] = artifact["metadata"]["text"] + if "phonemes" in artifact["metadata"]: + metadata["phonemes"] = artifact["metadata"]["phonemes"] + if "language" in artifact["metadata"]: + metadata["language"] = artifact["metadata"]["language"] + if "original_length" in artifact["metadata"] and "sample_rate" in artifact["metadata"]: + metadata["duration"] = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"] + + # rephonemize if required + if "phonemes" not in metadata and "text" in metadata: + metadata["phonemes"] = encode_phns( metadata["text"], language=metadata["language"] if "language" in metadata["language"] else "en" ) + + # clean up phonemes from espeak + # for example: Sonnenküste Update => zˈɔnənkˌystə (en)ˈʌpdeɪt(de) + # to-do: regex replace /([a-z]{2})/ to "" + if "phonemes" in metadata: + metadata["phonemes"] = metadata["phonemes"].replace("(en)", "") + if "phonemes" in metadata and "language" in metadata: + metadata["phonemes"] = metadata["phonemes"].replace(f"({metadata['language']})", "") + + return metadata + # parse dataset into better to sample metadata def create_dataset_metadata( skip_existing=True ): symmap = get_phone_symmap() @@ -1369,18 +1396,10 @@ def create_dataset_metadata( skip_existing=True ): utterance_metadata = {} if audios: - # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt - dac = np.load(quant_path, allow_pickle=True)[()] - qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16) + artifact = np.load(quant_path, allow_pickle=True)[()] + qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16) - if "text" in dac["metadata"]: - utterance_metadata["text"] = dac["metadata"]["text"] - if "phonemes" in dac["metadata"]: - utterance_metadata["phonemes"] = dac["metadata"]["phonemes"] - if "language" in dac["metadata"]: - utterance_metadata["language"] = dac["metadata"]["language"] - if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]: - utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"] + utterance_metadata = process_artifact_metadata( artifact ) for k, v in utterance_metadata.items(): metadata[id][k] = v @@ -1484,17 +1503,10 @@ def create_dataset_hdf5( skip_existing=True ): # audio if audios: - dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()] - qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16) + artifact = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()] + qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16) - if "text" in dac["metadata"]: - utterance_metadata["text"] = dac["metadata"]["text"] - if "phonemes" in dac["metadata"]: - utterance_metadata["phonemes"] = dac["metadata"]["phonemes"] - if "language" in dac["metadata"]: - utterance_metadata["language"] = dac["metadata"]["language"] - if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]: - utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"] + utterance_metadata = process_artifact_metadata( artifact ) if "audio" not in group: group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf')