From 9710b06b745aa896788690ca6d14582c5b6fcac3 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 6 Aug 2024 08:17:25 -0500
Subject: [PATCH] tweaks and things

---
 scripts/process_libritts.py | 13 +++++++++++--
 vall_e/emb/process.py       | 24 ++++++++++++++++--------
 vall_e/emb/qnt.py           | 20 ++++++++------------
 vall_e/emb/transcribe.py    | 22 ++++++++++++++++++++--
 4 files changed, 55 insertions(+), 24 deletions(-)

diff --git a/scripts/process_libritts.py b/scripts/process_libritts.py
index 250c2d7..7152ac9 100755
--- a/scripts/process_libritts.py
+++ b/scripts/process_libritts.py
@@ -119,8 +119,10 @@ def process(
 				continue
 
 			inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}/{filename}')
-			if not inpath.exists():
+			textpath = _replace_file_extension(inpath, ".original.txt")
+			if not inpath.exists() or not textpath.exists():
 				missing["audio"].append(str(inpath))
+				continue
 
 			extension = os.path.splitext(filename)[-1][1:]
 			fname = filename.replace(f'.{extension}', "")
@@ -129,7 +131,7 @@ def process(
 			language = "en"
 
 			outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
-			text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
+			text = open(textpath, "r", encoding="utf-8").read()
 
 			if len(text) == 0:
 				continue
@@ -214,6 +216,13 @@ def main():
 
 	args = parser.parse_args()
 
+	# do some assumption magic
+	# to-do: find a nice way to spawn multiple processes where tqdm plays nicely
+	if args.device.isnumeric():
+		args.stride = torch.cuda.device_count()
+		args.stride_offset = int(args.device)
+		args.device = f'cuda:{args.device}'
+
 	process(
 		audio_backend=args.audio_backend,
 		input_audio=args.input_audio,
diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py
index 3e7b767..6b01536 100644
--- a/vall_e/emb/process.py
+++ b/vall_e/emb/process.py
@@ -15,6 +15,10 @@ from pathlib import Path
 
 from ..config import cfg
 
+# need to validate if this is safe to import before modifying the config
+from .g2p import encode as phonemize
+from .qnt import encode as quantize, _replace_file_extension
+
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
 
@@ -58,11 +62,6 @@ def process(
 	cfg.inference.weight_dtype = dtype # "bfloat16"
 	cfg.inference.amp = amp # False
 
-	# import after because we've overriden the config above
-	# need to validate if this is even necessary anymore
-	from .g2p import encode as phonemize
-	from .qnt import encode as quantize, _replace_file_extension
-
 	output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
 
 	language_map = {} # k = group, v = language
@@ -272,6 +271,7 @@ def process(
 				})
 			except Exception as e:
 				print(f"Failed to quantize: {outpath}:", e)
+				torchaudio.save( waveform.cpu )
 				if raise_exceptions:
 					raise e
 				continue
@@ -283,19 +283,27 @@ def main():
 	parser = argparse.ArgumentParser()
 
 	parser.add_argument("--audio-backend", type=str, default="encodec")
-	parser.add_argument("--dtype", type=str, default="bfloat16")
-	parser.add_argument("--amp", action="store_true")
 	parser.add_argument("--input-audio", type=str, default="voices")
 	parser.add_argument("--input-metadata", type=str, default="training/metadata")
 	parser.add_argument("--output-dataset", type=str, default="training/dataset")
-	parser.add_argument("--device", type=str, default="cuda")
 	parser.add_argument("--raise-exceptions", action="store_true")
 	parser.add_argument("--stride", type=int, default=0)
 	parser.add_argument("--stride-offset", type=int, default=0)
 	parser.add_argument("--slice", type=str, default="auto")
+
+	parser.add_argument("--device", type=str, default="cuda")
+	parser.add_argument("--dtype", type=str, default="bfloat16")
+	parser.add_argument("--amp", action="store_true")
 
 	args = parser.parse_args()
 
+	# do some assumption magic
+	# to-do: find a nice way to spawn multiple processes where tqdm plays nicely
+	if args.device.isnumeric():
+		args.stride = torch.cuda.device_count()
+		args.stride_offset = int(args.device)
+		args.device = f'cuda:{args.device}'
+
 	process(
 		audio_backend=args.audio_backend,
 		input_audio=args.input_audio,
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index cd98a0b..3951104 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -252,18 +252,12 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 			dac_version='1.0.0',
 		)
 		dummy = True
+	elif hasattr( metadata, "__dict__" ):
+		metadata = metadata.__dict__
+		metadata.pop("codes")
+
 	# generate object with copied metadata
-	artifact = DACFile(
-		codes = codes,
-		# yes I can **kwargs from a dict but what if I want to pass the actual DACFile.metadata from elsewhere
-		chunk_length = metadata["chunk_length"] if isinstance(metadata, dict) else metadata.chunk_length,
-		original_length = metadata["original_length"] if isinstance(metadata, dict) else metadata.original_length,
-		input_db = metadata["input_db"] if isinstance(metadata, dict) else metadata.input_db,
-		channels = metadata["channels"] if isinstance(metadata, dict) else metadata.channels,
-		sample_rate = metadata["sample_rate"] if isinstance(metadata, dict) else metadata.sample_rate,
-		padding = metadata["padding"] if isinstance(metadata, dict) else metadata.padding,
-		dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
-	)
+	artifact = DACFile( codes = codes, **metadata )
 	artifact.dummy = dummy
 
 	# to-do: inject the sample rate encoded at, because we can actually decouple
@@ -368,7 +362,9 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod
 		levels = 8 if model.model_type == "24khz" else None
 
 		with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-			artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
+			# I guess it's safe to not encode in one chunk
+			#artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels)
+			artifact = model.compress(signal, verbose=False, n_quantizers=levels)
 
 		return artifact.codes if not return_metadata else artifact
 	# AudioDec uses a different pathway
diff --git a/vall_e/emb/transcribe.py b/vall_e/emb/transcribe.py
index ab9aad7..6b05281 100644
--- a/vall_e/emb/transcribe.py
+++ b/vall_e/emb/transcribe.py
@@ -17,6 +17,10 @@ from pathlib import Path
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
 
+def process_items( items, stride=0, stride_offset=0 ):
+	items = sorted( items )
+	return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]
+
 def transcribe(
 	input_audio = "voices",
 	output_metadata = "training/metadata",
@@ -25,6 +29,9 @@ def transcribe(
 	skip_existing = True,
 	diarize = False,
 
+	stride = 0,
+	stride_offset = 0,
+
 	batch_size = 16,
 	device = "cuda",
 	dtype = "float16",
@@ -42,7 +49,7 @@ def transcribe(
 		if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
 			continue
 
-		for speaker_id in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/'), desc="Processing speaker"):
+		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
 			if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
 				continue
 
@@ -55,7 +62,6 @@ def transcribe(
 			metadata = {}
 
 		for filename in tqdm(os.listdir(f'./{input_audio}/{dataset_name}/{speaker_id}/'), desc=f"Processing speaker: {speaker_id}"):
-
 			if skip_existing and filename in metadata:
 				continue
 
@@ -122,6 +128,8 @@ def main():
 	parser.add_argument("--skip-existing", action="store_true")
 	parser.add_argument("--diarize", action="store_true")
 	parser.add_argument("--batch-size", type=int, default=16)
+	parser.add_argument("--stride", type=int, default=0)
+	parser.add_argument("--stride-offset", type=int, default=0)
 
 	parser.add_argument("--device", type=str, default="cuda")
 	parser.add_argument("--dtype", type=str, default="bfloat16")
@@ -129,6 +137,13 @@ def main():
 	# parser.add_argument("--raise-exceptions", action="store_true")
 	args = parser.parse_args()
+
+	# do some assumption magic
+	# to-do: find a nice way to spawn multiple processes where tqdm plays nicely
+	if args.device.isnumeric():
+		args.stride = torch.cuda.device_count()
+		args.stride_offset = int(args.device)
+		args.device = f'cuda:{args.device}'
 
 	transcribe(
 		input_audio = args.input_audio,
 		output_metadata = args.output_metadata,
@@ -138,6 +153,9 @@ def main():
 		skip_existing = args.skip_existing,
 		diarize = args.diarize,
 
+		stride = args.stride,
+		stride_offset = args.stride_offset,
+
 		batch_size = args.batch_size,
 		device = args.device,
 		dtype = args.dtype,
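Aside, not part of the patch above: the block repeated in each main() turns a bare GPU index passed as --device into a sharding key. stride becomes torch.cuda.device_count(), stride_offset becomes that index, and process_items() keeps only every stride-th entry of the sorted item list. Below is a minimal sketch of that behaviour, reusing process_items() verbatim from the diff; the speaker names are made up purely for illustration.

# sketch: how the stride/stride_offset convention shards a speaker list across GPUs
def process_items( items, stride=0, stride_offset=0 ):
	items = sorted( items )
	return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ]

# hypothetical speaker directories, not from the patch
speakers = [f"speaker_{i:02d}" for i in range(6)]

# what "--device 0" resolves to on a two-GPU machine: stride=2, stride_offset=0 -> indices 0, 2, 4
print(process_items(speakers, stride=2, stride_offset=0))
# what "--device 1" resolves to: stride=2, stride_offset=1 -> indices 1, 3, 5
print(process_items(speakers, stride=2, stride_offset=1))

Running one copy of the script per GPU (here --device 0 and --device 1) yields disjoint shards that together cover every speaker, which is what the "assumption magic" stands in for until the to-do about spawning multiple processes with tqdm-friendly output is addressed.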