updated vall_e.emb.process to allow for batched processing, some typo fixes (it's painfully slow on my 7900XTX...)

This commit is contained in:
mrq 2025-02-05 21:13:20 -06:00
parent 79c504c278
commit 7592befc53
2 changed files with 81 additions and 7 deletions

View File

@ -817,13 +817,13 @@ class Config(BaseConfig):
cfg.model.resp_levels = 9
elif cfg.audio_backend == "audiodec":
audio_extension = ".dec"
sample_rate = 48_000
cfg.sample_rate = 48_000
cfg.model.resp_levels = 8 # ?
elif cfg.audio_backend == "nemo":
audio_extension = ".nem"
sample_rate = 44_100
cfg.sample_rate = 44_100
cfg.model.resp_levels = 8
cfg.model.audio_tokens = 1000
#cfg.model.audio_tokens = 1000
else:
raise Exception(f"Unknown audio backend: {audio_backend}")

View File

@ -20,7 +20,7 @@ from ..config import cfg
# need to validate if this is safe to import before modifying the config
from .g2p import encode as phonemize
from .qnt import encode as quantize
from .qnt import encode as quantize, encode_batch as quantize_batch
def pad(num, zeroes):
return str(num).zfill(zeroes+1)
@ -75,9 +75,80 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
np.save(open(outpath, "wb"), state_dict)
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True ):
def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
    """Quantize queued jobs in batches and write one encoded file per job.

    Args:
        jobs: iterable of ``(outpath, waveform, sample_rate, text, language)``
            tuples. Nothing is done when it is empty or ``None``.
        speaker_id: label shown in the progress bar description only.
        device: forwarded unchanged to ``quantize_batch``.
        raise_exceptions: when true, a failure in ``quantize_batch`` is logged
            and re-raised; when false, the failing batch is logged and skipped.
        batch_size: maximum number of jobs handed to ``quantize_batch`` at once.
            Values below 1 are treated as 1.

    Returns:
        None. Results are written to each job's ``outpath`` via ``np.save``.
    """
    if not jobs:
        return

    jobs = list(jobs)
    # Guard against a zero/negative step; the reported batch size stays as given.
    step = max(1, batch_size)
    # Slice-based chunking keeps the trailing partial batch. The previous
    # buffer-based version re-tested `len(buffer) >= batch_size` after the loop,
    # which can never hold there, so any leftover jobs were silently dropped.
    batches = [jobs[start:start + step] for start in range(0, len(jobs), step)]

    for batch in tqdm(batches, desc=f'Quantizing {speaker_id} (batch size: {batch_size})'):
        wavs = [job[1] for job in batch]
        srs = [job[2] for job in batch]

        try:
            codes = quantize_batch(wavs, sr=srs, device=device)
        except Exception as e:
            # Report every file in the failed batch, not just whichever one a
            # leaked loop variable happened to point at.
            failed = ", ".join(str(job[0]) for job in batch)
            _logger.error(f"Failed to quantize: {failed}: {str(e)}")
            if raise_exceptions:
                raise
            continue

        # NOTE(review): assumes quantize_batch returns one result per input
        # waveform, in input order - confirm against its implementation.
        for (outpath, waveform, sample_rate, text, language), qnt in zip( batch, codes ):
            if cfg.audio_backend == "dac":
                # The "dac" branch reads codes and metadata off attributes of
                # the returned object rather than treating it as a tensor.
                state_dict = {
                    "codes": qnt.codes.cpu().numpy().astype(np.uint16),
                    "metadata": {
                        "original_length": qnt.original_length,
                        "sample_rate": qnt.sample_rate,
                        "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
                        "chunk_length": qnt.chunk_length,
                        "channels": qnt.channels,
                        "padding": qnt.padding,
                        "dac_version": "1.0.0",
                    },
                }
            else:
                state_dict = {
                    "codes": qnt.cpu().numpy().astype(np.uint16),
                    "metadata": {
                        "original_length": waveform.shape[-1],
                        "sample_rate": sample_rate,
                    },
                }

            # Transcription metadata is only attached when text was supplied.
            if text:
                text = text.strip()
                state_dict['metadata'] |= {
                    "text": text,
                    "phonemes": phonemize(text, language=language),
                    "language": language,
                }

            # Context manager so the output handle is closed deterministically.
            with open(outpath, "wb") as handle:
                np.save(handle, state_dict)
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
if not jobs:
return
# batch things
if batch_size > 1:
return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size )
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
outpath, waveform, sample_rate, text, language = job
@ -99,6 +170,7 @@ def process(
stride=0,
stride_offset=0,
slice="auto",
batch_size=1,
low_memory=False,
@ -262,12 +334,12 @@ def process(
# processes audio files one at a time
if low_memory:
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size )
jobs = []
# processes all audio files for a given speaker
if not low_memory:
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size )
jobs = []
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
@ -286,6 +358,7 @@ def main():
parser.add_argument("--stride", type=int, default=0)
parser.add_argument("--stride-offset", type=int, default=0)
parser.add_argument("--slice", type=str, default="auto")
parser.add_argument("--batch-size", type=int, default=0)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--dtype", type=str, default="bfloat16")
@ -315,6 +388,7 @@ def main():
stride=args.stride,
stride_offset=args.stride_offset,
slice=args.slice,
batch_size=args.batch_size,
low_memory=args.low_memory,