From b6ba2cc8e79cc1bf1d478f0b8f9138df04802e24 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 6 Aug 2024 14:24:40 -0500 Subject: [PATCH] tweaked vall_e.emb.process to instead process audio one file at a time instead of all the files for a given speaker to avoid OOMing on less-memory-filled systems with --low-memory --- vall_e/emb/process.py | 168 ++++++++++++++++++++++-------------------- vall_e/emb/qnt.py | 18 +++-- 2 files changed, 100 insertions(+), 86 deletions(-) diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 6b01536..5d32379 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -22,10 +22,66 @@ from .qnt import encode as quantize, _replace_file_extension def pad(num, zeroes): return str(num).zfill(zeroes+1) +def load_audio( path, device ): + waveform, sr = torchaudio.load( path ) + if waveform.shape[0] > 1: + # mix channels + waveform = torch.mean(waveform, dim=0, keepdim=True) + return waveform.to(device=device), sr + def process_items( items, stride=0, stride_offset=0 ): items = sorted( items ) return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] +def process_job( outpath, text, language, waveform, sample_rate ): + phones = phonemize(text, language=language) + qnt = quantize(waveform, sr=sample_rate, device=waveform.device) + + if cfg.audio_backend == "dac": + np.save(open(outpath, "wb"), { + "codes": qnt.codes.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": qnt.original_length, + "sample_rate": qnt.sample_rate, + + "input_db": qnt.input_db.cpu().numpy().astype(np.float32), + "chunk_length": qnt.chunk_length, + "channels": qnt.channels, + "padding": qnt.padding, + "dac_version": "1.0.0", + + "text": text.strip(), + "phonemes": "".join(phones), + "language": language, + }, + }) + else: + np.save(open(outpath, "wb"), { + "codes": qnt.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": waveform.shape[-1], + "sample_rate": sample_rate, + + "text": text.strip(), + "phonemes": "".join(phones), + "language": language, + }, + }) + +def process_jobs( jobs, speaker_id="", raise_exceptions=True ): + if not jobs: + return + + for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"): + outpath, text, language, waveform, sample_rate = job + try: + process_job( outpath, text, language, waveform, sample_rate ) + except Exception as e: + print(f"Failed to quantize: {outpath}:", e) + if raise_exceptions: + raise e + continue + def process( audio_backend="encodec", input_audio="voices", @@ -36,6 +92,8 @@ def process( stride_offset=0, slice="auto", + low_memory=False, + device="cuda", dtype="float16", amp=False, @@ -73,6 +131,7 @@ def process( only_speakers = [] # only process these speakers always_slice_groups = [] # always slice from this group + audio_only = ["Noise"] # special pathway for processing audio only (without a transcription) missing = { "transcription": [], @@ -102,19 +161,20 @@ def process( os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True) - if speaker_id == "Noise": + if speaker_id in audio_only: for filename in sorted(os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/')): inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}') outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{filename}') + outpath = _replace_file_extension(outpath, audio_extension) - if _replace_file_extension(outpath, audio_extension).exists(): + if outpath.exists(): continue - waveform, sample_rate = torchaudio.load(inpath) + waveform, sample_rate = load_audio( inpath, device ) qnt = quantize(waveform, sr=sample_rate, device=device) if cfg.audio_backend == "dac": - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { + np.save(open(outpath, "wb"), { "codes": qnt.codes.cpu().numpy().astype(np.uint16), "metadata": { "original_length": qnt.original_length, @@ -128,7 +188,7 @@ def process( }, }) else: - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { + np.save(open(outpath, "wb"), { "codes": qnt.cpu().numpy().astype(np.uint16), "metadata": { "original_length": waveform.shape[-1], @@ -152,8 +212,7 @@ def process( if f'{group_name}/{speaker_id}' not in dataset: dataset.append(f'{group_name}/{speaker_id}') - txts = [] - wavs = [] + jobs = [] use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups @@ -171,26 +230,17 @@ def process( if len(metadata[filename]["segments"]) == 0 or not use_slices: outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}') + outpath = _replace_file_extension(outpath, audio_extension) text = metadata[filename]["text"] - if len(text) == 0: - continue - - if _replace_file_extension(outpath, audio_extension).exists(): + if len(text) == 0 or outpath.exists(): continue + # audio not already loaded, load it if waveform is None: - waveform, sample_rate = torchaudio.load(inpath) - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) + waveform, sample_rate = load_audio( inpath, device ) - wavs.append(( - outpath, - text, - language, - waveform, - sample_rate - )) + jobs.append(( outpath, text, language, waveform, sample_rate )) else: i = 0 for segment in metadata[filename]["segments"]: @@ -198,18 +248,15 @@ def process( i = i + 1 outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}_{id}.{extension}') + outpath = _replace_file_extension(outpath, audio_extension) text = segment["text"] - if len(text) == 0: - continue - - if _replace_file_extension(outpath, audio_extension).exists(): + if len(text) == 0 or outpath.exists(): continue + # audio not already loaded, load it if waveform is None: - waveform, sample_rate = torchaudio.load(inpath) - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) + waveform, sample_rate = load_audio( inpath, device ) start = int(segment['start'] * sample_rate) end = int(segment['end'] * sample_rate) @@ -222,59 +269,17 @@ def process( if end - start < 0: continue - wavs.append(( - outpath, - text, - language, - waveform[:, start:end], - sample_rate - )) + jobs.append(( outpath, text, language, waveform[:, start:end], sample_rate )) - if len(wavs) > 0: - for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"): - try: - outpath, text, language, waveform, sample_rate = job - - phones = phonemize(text, language=language) - qnt = quantize(waveform, sr=sample_rate, device=device) - - - if cfg.audio_backend == "dac": - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": qnt.original_length, - "sample_rate": qnt.sample_rate, - - "input_db": qnt.input_db.cpu().numpy().astype(np.float32), - "chunk_length": qnt.chunk_length, - "channels": qnt.channels, - "padding": qnt.padding, - "dac_version": "1.0.0", - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, - }, - }) - else: - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": waveform.shape[-1], - "sample_rate": sample_rate, - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, - }, - }) - except Exception as e: - print(f"Failed to quantize: {outpath}:", e) - torchaudio.save( waveform.cpu ) - if raise_exceptions: - raise e - continue + # processes audio files one at a time + if low_memory: + process_jobs( jobs, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions ) + jobs = [] + + # processes all audio files for a given speaker + if not low_memory: + process_jobs( jobs, speaker_id=speaker_id, raise_exceptions=raise_exceptions ) + jobs = [] open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) @@ -287,6 +292,7 @@ def main(): parser.add_argument("--input-metadata", type=str, default="training/metadata") parser.add_argument("--output-dataset", type=str, default="training/dataset") parser.add_argument("--raise-exceptions", action="store_true") + parser.add_argument("--low-memory", action="store_true") parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") @@ -313,6 +319,8 @@ def main(): stride=args.stride, stride_offset=args.stride_offset, slice=args.slice, + + low_memory=args.low_memory, device=args.device, dtype=args.dtype, diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 3951104..389620c 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -254,10 +254,17 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N dummy = True elif hasattr( metadata, "__dict__" ): metadata = metadata.__dict__ - metadata.pop("codes") - # generate object with copied metadata - artifact = DACFile( codes = codes, **metadata ) + artifact = DACFile( + codes = codes, + chunk_length = metadata["chunk_length"], + original_length = metadata["original_length"], + input_db = metadata["input_db"], + channels = metadata["channels"], + sample_rate = metadata["sample_rate"], + padding = metadata["padding"], + dac_version = metadata["dac_version"], + ) artifact.dummy = dummy # to-do: inject the sample rate encoded at, because we can actually decouple @@ -362,9 +369,8 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda", levels=cfg.mod levels = 8 if model.model_type == "24khz" else None with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): - # I guess it's safe to not encode in one chunk - #artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels) - artifact = model.compress(signal, verbose=False, n_quantizers=levels) + artifact = model.compress(signal, win_duration=None, verbose=False, n_quantizers=levels) + #artifact = model.compress(signal, n_quantizers=levels) return artifact.codes if not return_metadata else artifact # AudioDec uses a different pathway