""" # Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py) # Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset """ import os import json import argparse import torch import torchaudio import numpy as np from tqdm.auto import tqdm from pathlib import Path from ..config import cfg def pad(num, zeroes): return str(num).zfill(zeroes+1) def process_items( items, stride=0, stride_offset=0 ): items = sorted( items ) return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] def process( audio_backend="encodec", input_audio="voices", input_metadata="metadata", output_dataset="training", raise_exceptions=False, stride=0, stride_offset=0, slice="auto", device="cuda", dtype="float16", amp=False, ): # encodec / vocos if audio_backend in ["encodec", "vocos"]: audio_extension = ".enc" cfg.sample_rate = 24_000 cfg.model.resp_levels = 8 elif audio_backend == "dac": audio_extension = ".dac" cfg.sample_rate = 44_100 cfg.model.resp_levels = 9 elif cfg.audio_backend == "audiodec": sample_rate = 48_000 audio_extension = ".dec" cfg.model.resp_levels = 8 # ? else: raise Exception(f"Unknown audio backend: {audio_backend}") # prepare from args cfg.audio_backend = audio_backend # "encodec" cfg.inference.weight_dtype = dtype # "bfloat16" cfg.inference.amp = amp # False # import after because we've overriden the config above # need to validate if this is even necessary anymore from .g2p import encode as phonemize from .qnt import encode as quantize, _replace_file_extension output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" language_map = {} # k = group, v = language ignore_groups = [] # skip these groups ignore_speakers = [] # skip these speakers only_groups = [] # only process these groups only_speakers = [] # only process these speakers always_slice_groups = [] # always slice from this group missing = { "transcription": [], "audio": [] } dataset = [] for group_name in sorted(os.listdir(f'./{input_audio}/')): if not os.path.isdir(f'./{input_audio}/{group_name}/'): print("Is not dir:", f'./{input_audio}/{group_name}/') continue if group_name in ignore_groups: continue if only_groups and group_name not in only_groups: continue for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"): if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'): print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}') continue if speaker_id in ignore_speakers: continue if only_speakers and speaker_id not in only_speakers: continue os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True) if speaker_id == "Noise": for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')): inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}') outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}') if _replace_file_extension(outpath, audio_extension).exists(): continue waveform, sample_rate = torchaudio.load(inpath) qnt = quantize(waveform, sr=sample_rate, device=device) if cfg.audio_backend == "dac": np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { "codes": 
qnt.codes.cpu().numpy().astype(np.uint16), "metadata": { "original_length": qnt.original_length, "sample_rate": qnt.sample_rate, "input_db": qnt.input_db.cpu().numpy().astype(np.float32), "chunk_length": qnt.chunk_length, "channels": qnt.channels, "padding": qnt.padding, "dac_version": "1.0.0", }, }) else: np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { "codes": qnt.cpu().numpy().astype(np.uint16), "metadata": { "original_length": waveform.shape[-1], "sample_rate": sample_rate, }, }) continue metadata_path = Path(f'./{input_metadata}/{group_name}/{speaker_id}/whisper.json') if not metadata_path.exists(): missing["transcription"].append(str(metadata_path)) continue try: metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read()) except Exception as e: missing["transcription"].append(str(metadata_path)) continue if f'{group_name}/{speaker_id}' not in dataset: dataset.append(f'{group_name}/{speaker_id}') txts = [] wavs = [] use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups for filename in sorted(metadata.keys()): inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}') if not inpath.exists(): missing["audio"].append(str(inpath)) continue extension = os.path.splitext(filename)[-1][1:] fname = filename.replace(f'.{extension}', "") waveform, sample_rate = None, None language = language_map[group_name] if group_name in language_map else (metadata[filename]["language"] if "language" in metadata[filename] else "en") if len(metadata[filename]["segments"]) == 0 or not use_slices: outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}') text = metadata[filename]["text"] if len(text) == 0: continue if _replace_file_extension(outpath, audio_extension).exists(): continue if waveform is None: waveform, sample_rate = torchaudio.load(inpath) if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) wavs.append(( outpath, text, language, waveform, sample_rate )) else: i = 0 for segment in metadata[filename]["segments"]: id = pad(i, 4) i = i + 1 outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}') text = segment["text"] if len(text) == 0: continue if _replace_file_extension(outpath, audio_extension).exists(): continue if waveform is None: waveform, sample_rate = torchaudio.load(inpath) if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) start = int(segment['start'] * sample_rate) end = int(segment['end'] * sample_rate) if start < 0: start = 0 if end >= waveform.shape[-1]: end = waveform.shape[-1] - 1 if end - start < 0: continue wavs.append(( outpath, text, language, waveform[:, start:end], sample_rate )) if len(wavs) > 0: for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"): try: outpath, text, language, waveform, sample_rate = job phones = phonemize(text, language=language) qnt = quantize(waveform, sr=sample_rate, device=device) if cfg.audio_backend == "dac": np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { "codes": qnt.codes.cpu().numpy().astype(np.uint16), "metadata": { "original_length": qnt.original_length, "sample_rate": qnt.sample_rate, "input_db": qnt.input_db.cpu().numpy().astype(np.float32), "chunk_length": qnt.chunk_length, "channels": qnt.channels, "padding": qnt.padding, "dac_version": "1.0.0", "text": text.strip(), "phonemes": "".join(phones), "language": language, }, }) else: np.save(open(_replace_file_extension(outpath, audio_extension), 
"wb"), { "codes": qnt.cpu().numpy().astype(np.uint16), "metadata": { "original_length": waveform.shape[-1], "sample_rate": sample_rate, "text": text.strip(), "phonemes": "".join(phones), "language": language, }, }) except Exception as e: print(f"Failed to quantize: {outpath}:", e) if raise_exceptions: raise e continue open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) def main(): parser = argparse.ArgumentParser() parser.add_argument("--audio-backend", type=str, default="encodec") parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--amp", action="store_true") parser.add_argument("--input-audio", type=str, default="voices") parser.add_argument("--input-metadata", type=str, default="training/metadata") parser.add_argument("--output-dataset", type=str, default="training/dataset") parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--raise-exceptions", action="store_true") parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") args = parser.parse_args() process( audio_backend=args.audio_backend, input_audio=args.input_audio, input_metadata=args.input_metadata, output_dataset=args.output_dataset, raise_exceptions=args.raise_exceptions, stride=args.stride, stride_offset=args.stride_offset, slice=args.slice, device=args.device, dtype=args.dtype, amp=args.amp, ) if __name__ == "__main__": main()