From 7cdfa3dc0c5e89b1a1fe7007ef8ea3df206a7257 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 5 Aug 2024 15:59:25 -0500
Subject: [PATCH] updated process_dataset.py, added argparsing so I can mostly stop manually editing things, and some other cleanup

---
 scripts/process_dataset.py | 402 +++++++++++++++++++++----------------
 vall_e/models/base.py      |  11 +-
 2 files changed, 233 insertions(+), 180 deletions(-)

diff --git a/scripts/process_dataset.py b/scripts/process_dataset.py
index 792d8bd..19cec86 100644
--- a/scripts/process_dataset.py
+++ b/scripts/process_dataset.py
@@ -1,191 +1,104 @@
 import os
 import json
+import argparse
 import torch
 import torchaudio
 import numpy as np
+
 from tqdm.auto import tqdm
 from pathlib import Path
 
 from vall_e.config import cfg
 
-# things that could be args
-cfg.sample_rate = 24_000
-cfg.audio_backend = "encodec"
-"""
-cfg.inference.weight_dtype = "bfloat16"
-cfg.inference.dtype = torch.bfloat16
-cfg.inference.amp = True
-"""
-
-from vall_e.emb.g2p import encode as valle_phonemize
-from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
-
-input_audio = "voices"
-input_metadata = "metadata"
-output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}"
-device = "cuda"
-
-audio_extension = ".enc"
-if cfg.audio_backend == "dac":
-	audio_extension = ".dac"
-elif cfg.audio_backend == "audiodec":
-	audio_extension = ".dec"
-
-slice = "auto"
-missing = {
-	"transcription": [],
-	"audio": []
-}
-dataset = []
 
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
 
-for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
-	if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
-		print("Is not dir:", f'./{input_audio}/{dataset_name}/')
-		continue
+def process_items( items, stride=0 ):
+	items = sorted( items )
+	return items if stride == 0 else [ item for i, item in enumerate( items ) if i % stride == 0 ]
 
-	for speaker_id in tqdm(sorted(os.listdir(f'./{input_audio}/{dataset_name}/')), desc=f"Processing speaker in {dataset_name}"):
-		if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
-			print("Is not dir:", f'./{input_audio}/{dataset_name}/{speaker_id}')
-			continue
-
-		os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
+def process_dataset( args ):
+	# encodec / vocos
 
-		if speaker_id == "Noise":
-			for filename in sorted(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/')):
-				inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
-				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{filename}')
+	if args.audio_backend in ["encodec", "vocos"]:
+		audio_extension = ".enc"
+		cfg.sample_rate = 24_000
+		cfg.model.resp_levels = 8
+	elif args.audio_backend == "dac":
+		audio_extension = ".dac"
+		cfg.sample_rate = 44_100
+		cfg.model.resp_levels = 9
+	elif args.audio_backend == "audiodec":
+		cfg.sample_rate = 48_000
+		audio_extension = ".dec"
+		cfg.model.resp_levels = 8 # ?
+	else:
+		raise Exception(f"Unknown audio backend: {args.audio_backend}")
 
-			if _replace_file_extension(outpath, audio_extension).exists():
-				continue
+	# prepare from args
+	cfg.audio_backend = args.audio_backend # "encodec"
+	cfg.inference.weight_dtype = args.dtype # "bfloat16"
+	cfg.inference.amp = args.amp # False
 
-			waveform, sample_rate = torchaudio.load(inpath)
-			qnt = valle_quantize(waveform, sr=sample_rate, device=device)
+	# import after because we've overridden the config above
+	from vall_e.emb.g2p import encode as valle_phonemize
+	from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
 
-			if cfg.audio_backend == "dac":
-				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-					"codes": qnt.codes.cpu().numpy().astype(np.uint16),
-					"metadata": {
-						"original_length": qnt.original_length,
-						"sample_rate": qnt.sample_rate,
-
-						"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-						"chunk_length": qnt.chunk_length,
-						"channels": qnt.channels,
-						"padding": qnt.padding,
-						"dac_version": "1.0.0",
-					},
-				})
-			else:
-				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-					"codes": qnt.cpu().numpy().astype(np.uint16),
-					"metadata": {
-						"original_length": waveform.shape[-1],
-						"sample_rate": sample_rate,
-					},
-				})
+	input_audio = args.input_audio # "voices"
+	input_metadata = args.input_metadata # "metadata"
+	output_group = f"{args.output_group}-{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
+	device = args.device # "cuda"
+	raise_exceptions = args.raise_exceptions # False
+	stride = args.stride # 0
+	slice = args.slice # "auto"
 
-			continue
-
-		metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/whisper.json')
-		if not metadata_path.exists():
-			missing["transcription"].append(str(metadata_path))
+	language_map = {} # k = group, v = language
+
+	ignore_groups = [] # skip these groups
+	ignore_speakers = [] # skip these speakers
+
+	only_groups = [] # only process these groups
+	only_speakers = [] # only process these speakers
+
+	always_slice_groups = [] # always slice from these groups
+
+	missing = {
+		"transcription": [],
+		"audio": []
+	}
+	dataset = []
+
+	for group_name in sorted(os.listdir(f'./{input_audio}/')):
+		if not os.path.isdir(f'./{input_audio}/{group_name}/'):
+			print("Is not dir:", f'./{input_audio}/{group_name}/')
 			continue
 
-		try:
-			metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
-		except Exception as e:
-			missing["transcription"].append(str(metadata_path))
+		if group_name in ignore_groups:
+			continue
+		if only_groups and group_name not in only_groups:
 			continue
 
-		if f'{dataset_name}/{speaker_id}' not in dataset:
-			dataset.append(f'{dataset_name}/{speaker_id}')
-
-		txts = []
-		wavs = []
-
-		use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or dataset_name in ["LibriVox", "Audiobooks"]
-
-		for filename in sorted(metadata.keys()):
-			inpath = Path(f'./{input_audio}/{dataset_name}/{speaker_id}/{filename}')
-			if not inpath.exists():
-				missing["audio"].append(str(inpath))
+		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride), desc=f"Processing speakers in {group_name}"):
+			if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
+				print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
 				continue
 
-			extension = os.path.splitext(filename)[-1][1:]
-			fname = filename.replace(f'.{extension}', "")
+			if speaker_id in ignore_speakers:
+				continue
+			if only_speakers and speaker_id not in only_speakers:
+				continue
 
-			waveform, sample_rate = None, None
-			language = metadata[filename]["language"] if "language" in metadata[filename] else "en"
+			os.makedirs(f'./{output_group}/{group_name}/{speaker_id}/', exist_ok=True)
 
-			if len(metadata[filename]["segments"]) == 0 or not use_slices:
-				outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}.{extension}')
-				text = metadata[filename]["text"]
-
-				if len(text) == 0:
-					continue
-
-				if _replace_file_extension(outpath, audio_extension).exists():
-					continue
-
-				if waveform is None:
-					waveform, sample_rate = torchaudio.load(inpath)
-					if waveform.shape[0] > 1:
-						waveform = torch.mean(waveform, dim=0, keepdim=True)
-
-				wavs.append((
-					outpath,
-					text,
-					language,
-					waveform,
-					sample_rate
-				))
-			else:
-				i = 0
-				for segment in metadata[filename]["segments"]:
-					id = pad(i, 4)
-					i = i + 1
-
-					outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/{fname}_{id}.{extension}')
-					text = segment["text"]
-
-					if len(text) == 0:
-						continue
+			if speaker_id == "Noise":
+				for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')):
+					inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
+					outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{filename}')
 
 					if _replace_file_extension(outpath, audio_extension).exists():
 						continue
 
-					if waveform is None:
-						waveform, sample_rate = torchaudio.load(inpath)
-						if waveform.shape[0] > 1:
-							waveform = torch.mean(waveform, dim=0, keepdim=True)
-
-					start = int(segment['start'] * sample_rate)
-					end = int(segment['end'] * sample_rate)
-
-					if start < 0:
-						start = 0
-					if end >= waveform.shape[-1]:
-						end = waveform.shape[-1] - 1
-
-					if end - start < 0:
-						continue
-
-					wavs.append((
-						outpath,
-						text,
-						language,
-						waveform[:, start:end],
-						sample_rate
-					))
-
-		if len(wavs) > 0:
-			for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
-				try:
-					outpath, text, language, waveform, sample_rate = job
-
-					phones = valle_phonemize( text, language=language )
+					waveform, sample_rate = torchaudio.load(inpath)
 					qnt = valle_quantize(waveform, sr=sample_rate, device=device)
 
@@ -200,10 +113,6 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 						"channels": qnt.channels,
 						"padding": qnt.padding,
 						"dac_version": "1.0.0",
-
-						"text": text.strip(),
-						"phonemes": "".join(phones),
-						"language": language,
 					},
 				})
 			else:
@@ -212,15 +121,168 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 					"metadata": {
 						"original_length": waveform.shape[-1],
 						"sample_rate": sample_rate,
-
-						"text": text.strip(),
-						"phonemes": "".join(phones),
-						"language": language,
 					},
 				})
-				except Exception as e:
-					print(f"Failed to quantize: {outpath}:", e)
-					continue
-
-open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
-open("./dataset_list.json", 'w', encoding='utf-8').write(json.dumps(dataset))
\ No newline at end of file
+
+				continue
+
+			metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json')
+			if not metadata_path.exists():
+				missing["transcription"].append(str(metadata_path))
+				continue
+
+			try:
+				metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
+			except Exception as e:
+				missing["transcription"].append(str(metadata_path))
+				continue
+
+			if f'{group_name}/{speaker_id}' not in dataset:
+				dataset.append(f'{group_name}/{speaker_id}')
+
+			txts = []
+			wavs = []
+
+			use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
+
+			for filename in sorted(metadata.keys()):
+				inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
+				if not inpath.exists():
+					missing["audio"].append(str(inpath))
+					continue
+
+				extension = os.path.splitext(filename)[-1][1:]
+				fname = filename.replace(f'.{extension}', "")
+
+				waveform, sample_rate = None, None
+				language = language_map[group_name] if group_name in language_map else (metadata[filename]["language"] if "language" in metadata[filename] else "en")
+
+				if len(metadata[filename]["segments"]) == 0 or not use_slices:
+					outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{fname}.{extension}')
+					text = metadata[filename]["text"]
+
+					if len(text) == 0:
+						continue
+
+					if _replace_file_extension(outpath, audio_extension).exists():
+						continue
+
+					if waveform is None:
+						waveform, sample_rate = torchaudio.load(inpath)
+						if waveform.shape[0] > 1:
+							waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+					wavs.append((
+						outpath,
+						text,
+						language,
+						waveform,
+						sample_rate
+					))
+				else:
+					i = 0
+					for segment in metadata[filename]["segments"]:
+						id = pad(i, 4)
+						i = i + 1
+
+						outpath = Path(f'./{output_group}/{group_name}/{speaker_id}/{fname}_{id}.{extension}')
+						text = segment["text"]
+
+						if len(text) == 0:
+							continue
+
+						if _replace_file_extension(outpath, audio_extension).exists():
+							continue
+
+						if waveform is None:
+							waveform, sample_rate = torchaudio.load(inpath)
+							if waveform.shape[0] > 1:
+								waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+						start = int(segment['start'] * sample_rate)
+						end = int(segment['end'] * sample_rate)
+
+						if start < 0:
+							start = 0
+						if end >= waveform.shape[-1]:
+							end = waveform.shape[-1] - 1
+
+						if end - start < 0:
+							continue
+
+						wavs.append((
+							outpath,
+							text,
+							language,
+							waveform[:, start:end],
+							sample_rate
+						))
+
+			if len(wavs) > 0:
+				for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
+					try:
+						outpath, text, language, waveform, sample_rate = job
+
+						phones = valle_phonemize(text, language=language)
+						qnt = valle_quantize(waveform, sr=sample_rate, device=device)
+
+
+						if cfg.audio_backend == "dac":
+							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+								"codes": qnt.codes.cpu().numpy().astype(np.uint16),
+								"metadata": {
+									"original_length": qnt.original_length,
+									"sample_rate": qnt.sample_rate,
+
+									"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+									"chunk_length": qnt.chunk_length,
+									"channels": qnt.channels,
+									"padding": qnt.padding,
+									"dac_version": "1.0.0",
+
+									"text": text.strip(),
+									"phonemes": "".join(phones),
+									"language": language,
+								},
+							})
+						else:
+							np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
+								"codes": qnt.cpu().numpy().astype(np.uint16),
+								"metadata": {
+									"original_length": waveform.shape[-1],
+									"sample_rate": sample_rate,
+
+									"text": text.strip(),
+									"phonemes": "".join(phones),
+									"language": language,
+								},
+							})
+					except Exception as e:
+						print(f"Failed to quantize: {outpath}:", e)
+						if raise_exceptions:
+							raise e
+						continue
+
+	open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
+	open("./dataset_list.json", 'w', encoding='utf-8').write(json.dumps(dataset))
+
+def main():
+	parser = argparse.ArgumentParser()
+
+	parser.add_argument("--audio-backend", type=str, default="encodec")
+	parser.add_argument("--dtype", type=str, default="bfloat16")
+	parser.add_argument("--amp", action="store_true")
+	parser.add_argument("--input-audio", type=str, default="voices")
+	parser.add_argument("--input-metadata", type=str, default="metadata")
+	parser.add_argument("--output-group", type=str, default="training")
+	parser.add_argument("--device", type=str, default="cuda")
+	parser.add_argument("--raise-exceptions", action="store_true")
+	parser.add_argument("--stride", type=int, default=0)
+	parser.add_argument("--slice", type=str, default="auto")
+
+	args = parser.parse_args()
+
+	process_dataset( args )
+
+if __name__ == "__main__":
+	main()
\ No newline at end of file
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 2a40f28..74b07b9 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -813,7 +813,7 @@ class Base(nn.Module):
 				inputs_embeds=x,
 				past_key_values=state,
 				position_ids=position_ids,
-				use_cache=True,
+				use_cache=not self.training,
 			#	return_dict=True,
 			)
 			if self.n_experts > 1 and self.training:
@@ -1350,15 +1350,6 @@ class Base(nn.Module):
 		x, m = list_to_tensor(x_list)
 		training = self.training
 
-		# yes, there's a better way.
-		"""
-		training = False
-		for batch_index, batch in enumerate(inputs):
-			for name, input in batch:
-				if name == "targ":
-					training = True
-		"""
-
 		device = x.device
 		batch_size = len(x_list)
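
Note on the stored format (not part of the patch): with the new argparser a run is driven entirely from the command line, e.g. "python3 scripts/process_dataset.py --audio-backend encodec --dtype bfloat16 --device cuda". Each quantized utterance is written with np.save() on an already-open file handle, so the resulting .enc/.dac/.dec files are ordinary .npy files whose payload is a pickled dict. Below is a minimal sketch of reading one back, assuming only the dict layout shown in the patch; the load_quantized helper and the example path are hypothetical:

import numpy as np

def load_quantized(path):
	# np.save() pickles a plain dict into a 0-d object array, so reading it
	# back requires allow_pickle=True plus .item() to recover the dict.
	payload = np.load(path, allow_pickle=True).item()
	codes = payload["codes"]        # uint16 codebook indices
	metadata = payload["metadata"]  # original_length, sample_rate (+ text/phonemes/language for transcribed clips)
	return codes, metadata

# hypothetical output path for the encodec backend ("training-24KHz-encodec" output group)
codes, metadata = load_quantized("./training-24KHz-encodec/LibriVox/speaker_00/file_0000.enc")
print(codes.shape, metadata["sample_rate"])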