From 7592befc5312f769f566c631ea4236b2b0960116 Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 5 Feb 2025 21:13:20 -0600 Subject: [PATCH] updated vall_e.emb.process to allow for batched processing, some typo fixes (it's painfully slow on my 7900XTX...) --- vall_e/config.py | 6 ++-- vall_e/emb/process.py | 82 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 8b9b087..9b92fbe 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -817,13 +817,13 @@ class Config(BaseConfig): cfg.model.resp_levels = 9 elif cfg.audio_backend == "audiodec": audio_extension = ".dec" - sample_rate = 48_000 + cfg.sample_rate = 48_000 cfg.model.resp_levels = 8 # ? elif cfg.audio_backend == "nemo": audio_extension = ".nem" - sample_rate = 44_100 + cfg.sample_rate = 44_100 cfg.model.resp_levels = 8 - cfg.model.audio_tokens = 1000 + #cfg.model.audio_tokens = 1000 else: raise Exception(f"Unknown audio backend: {audio_backend}") diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index dee037f..035b697 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -20,7 +20,7 @@ from ..config import cfg # need to validate if this is safe to import before modifying the config from .g2p import encode as phonemize -from .qnt import encode as quantize +from .qnt import encode as quantize, encode_batch as quantize_batch def pad(num, zeroes): return str(num).zfill(zeroes+1) @@ -75,9 +75,80 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic np.save(open(outpath, "wb"), state_dict) -def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True ): +def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ): if not jobs: return + + buffer = [] + batches = [] + + for job in jobs: + buffer.append(job) + if len(buffer) >= batch_size: + batches.append(buffer) + buffer = [] + + if len(buffer) >= batch_size: + batches.append(buffer) + buffer = [] + + for batch in tqdm(batches, desc=f'Quantizing {speaker_id} (batch size: {batch_size})'): + wavs = [] + srs = [] + + for outpath, waveform, sample_rate, text, language in batch: + wavs.append(waveform) + srs.append(sample_rate) + + try: + codes = quantize_batch(wavs, sr=srs, device=device) + except Exception as e: + _logger.error(f"Failed to quantize: {outpath}: {str(e)}") + if raise_exceptions: + raise e + continue + + for (outpath, waveform, sample_rate, text, language), qnt in zip( batch, codes ): + if cfg.audio_backend == "dac": + state_dict = { + "codes": qnt.codes.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": qnt.original_length, + "sample_rate": qnt.sample_rate, + + "input_db": qnt.input_db.cpu().numpy().astype(np.float32), + "chunk_length": qnt.chunk_length, + "channels": qnt.channels, + "padding": qnt.padding, + "dac_version": "1.0.0", + }, + } + else: + state_dict = { + "codes": qnt.cpu().numpy().astype(np.uint16), + "metadata": { + "original_length": waveform.shape[-1], + "sample_rate": sample_rate, + }, + } + + if text: + text = text.strip() + state_dict['metadata'] |= { + "text": text, + "phonemes": phonemize(text, language=language), + "language": language, + } + + np.save(open(outpath, "wb"), state_dict) + +def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ): + if not jobs: + return + + # batch things + if batch_size > 1: + return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size ) for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"): outpath, waveform, sample_rate, text, language = job @@ -99,6 +170,7 @@ def process( stride=0, stride_offset=0, slice="auto", + batch_size=1, low_memory=False, @@ -262,12 +334,12 @@ def process( # processes audio files one at a time if low_memory: - process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions ) + process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size ) jobs = [] # processes all audio files for a given speaker if not low_memory: - process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions ) + process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size ) jobs = [] open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) @@ -286,6 +358,7 @@ def main(): parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") + parser.add_argument("--batch-size", type=int, default=0) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--dtype", type=str, default="bfloat16") @@ -315,6 +388,7 @@ def main(): stride=args.stride, stride_offset=args.stride_offset, slice=args.slice, + batch_size=args.batch_size, low_memory=args.low_memory,