updated vall_e.emb.process to allow for batched processing, some typo fixes (it's painfully slow on my 7900XTX...)
parent 79c504c278
commit 7592befc53
@@ -817,13 +817,13 @@ class Config(BaseConfig):
             cfg.model.resp_levels = 9
         elif cfg.audio_backend == "audiodec":
             audio_extension = ".dec"
             sample_rate = 48_000
             cfg.sample_rate = 48_000
             cfg.model.resp_levels = 8 # ?
         elif cfg.audio_backend == "nemo":
             audio_extension = ".nem"
             sample_rate = 44_100
             cfg.sample_rate = 44_100
             cfg.model.resp_levels = 8
-            cfg.model.audio_tokens = 1000
+            #cfg.model.audio_tokens = 1000
         else:
             raise Exception(f"Unknown audio backend: {audio_backend}")
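Read as a table, the backend settings this hunk touches come down to a small mapping. The sketch below only covers the two branches visible in the hunk (the encodec/vocos/dac branches sit above the shown context) and is an illustration, not how config.py actually structures the dispatch, which remains an if/elif chain:

    # Illustration only: the values the hunk above assigns, per audio backend.
    BACKEND_SETTINGS = {
        "audiodec": {"audio_extension": ".dec", "sample_rate": 48_000, "resp_levels": 8},
        # the nemo branch now leaves cfg.model.audio_tokens at its default,
        # since this commit comments out the assignment to 1000
        "nemo":     {"audio_extension": ".nem", "sample_rate": 44_100, "resp_levels": 8},
    }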
@@ -20,7 +20,7 @@ from ..config import cfg

 # need to validate if this is safe to import before modifying the config
 from .g2p import encode as phonemize
-from .qnt import encode as quantize
+from .qnt import encode as quantize, encode_batch as quantize_batch

 def pad(num, zeroes):
     return str(num).zfill(zeroes+1)
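The new quantize_batch alias points at encode_batch in .qnt, whose implementation is not part of this diff; the only thing visible here is the call site further down, which passes a list of waveforms and a matching list of sample rates. As a rough mental model only (not the real .qnt code, and assuming the single-item quantize accepts the same sr/device keywords as the batched call), a fallback that reduces batched encoding to the existing encoder could look like:

    # Hypothetical fallback, NOT the actual .qnt.encode_batch: loop the single-item
    # encoder so both entry points yield one set of codes per job.
    def quantize_batch_fallback(wavs, sr, device=None):
        # one sample rate per waveform, mirroring quantize_batch(wavs, sr=srs, device=device)
        return [quantize(wav, sr=rate, device=device) for wav, rate in zip(wavs, sr)]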
@@ -75,9 +75,80 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
     np.save(open(outpath, "wb"), state_dict)

-def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True ):
+def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
     if not jobs:
         return
+
+    buffer = []
+    batches = []
+
+    for job in jobs:
+        buffer.append(job)
+        if len(buffer) >= batch_size:
+            batches.append(buffer)
+            buffer = []
+
+    if len(buffer) >= batch_size:
+        batches.append(buffer)
+        buffer = []
+
+    for batch in tqdm(batches, desc=f'Quantizing {speaker_id} (batch size: {batch_size})'):
+        wavs = []
+        srs = []
+
+        for outpath, waveform, sample_rate, text, language in batch:
+            wavs.append(waveform)
+            srs.append(sample_rate)
+
+        try:
+            codes = quantize_batch(wavs, sr=srs, device=device)
+        except Exception as e:
+            _logger.error(f"Failed to quantize: {outpath}: {str(e)}")
+            if raise_exceptions:
+                raise e
+            continue
+
+        for (outpath, waveform, sample_rate, text, language), qnt in zip( batch, codes ):
+            if cfg.audio_backend == "dac":
+                state_dict = {
+                    "codes": qnt.codes.cpu().numpy().astype(np.uint16),
+                    "metadata": {
+                        "original_length": qnt.original_length,
+                        "sample_rate": qnt.sample_rate,
+
+                        "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
+                        "chunk_length": qnt.chunk_length,
+                        "channels": qnt.channels,
+                        "padding": qnt.padding,
+                        "dac_version": "1.0.0",
+                    },
+                }
+            else:
+                state_dict = {
+                    "codes": qnt.cpu().numpy().astype(np.uint16),
+                    "metadata": {
+                        "original_length": waveform.shape[-1],
+                        "sample_rate": sample_rate,
+                    },
+                }
+
+            if text:
+                text = text.strip()
+                state_dict['metadata'] |= {
+                    "text": text,
+                    "phonemes": phonemize(text, language=language),
+                    "language": language,
+                }
+
+            np.save(open(outpath, "wb"), state_dict)
+
+def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
+    if not jobs:
+        return
+
+    # batch things
+    if batch_size > 1:
+        return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size )

     for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
         outpath, waveform, sample_rate, text, language = job
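One detail of process_batched_jobs above worth flagging: the flush after the batching loop reuses the len(buffer) >= batch_size check, but at that point the buffer always holds fewer than batch_size jobs (the loop flushes as soon as it fills), so a trailing partial batch appears to be skipped rather than quantized. A minimal sketch of a flush that also emits the leftovers, using the same buffer/batches variables as the diff:

    # Sketch: after filling full batches, also emit whatever is left over so the
    # final len(jobs) % batch_size items still get quantized.
    for job in jobs:
        buffer.append(job)
        if len(buffer) >= batch_size:
            batches.append(buffer)
            buffer = []

    if buffer:  # partial final batch
        batches.append(buffer)
        buffer = []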
@@ -99,6 +170,7 @@ def process(
     stride=0,
     stride_offset=0,
     slice="auto",
+    batch_size=1,

     low_memory=False,
@@ -262,12 +334,12 @@ def process(

             # processes audio files one at a time
             if low_memory:
-                process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
+                process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size )
                 jobs = []

         # processes all audio files for a given speaker
         if not low_memory:
-            process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
+            process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size )
             jobs = []

     open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
@@ -286,6 +358,7 @@ def main():
     parser.add_argument("--stride", type=int, default=0)
     parser.add_argument("--stride-offset", type=int, default=0)
     parser.add_argument("--slice", type=str, default="auto")
+    parser.add_argument("--batch-size", type=int, default=0)

     parser.add_argument("--device", type=str, default="cuda")
     parser.add_argument("--dtype", type=str, default="bfloat16")
@@ -315,6 +388,7 @@ def main():
         stride=args.stride,
         stride_offset=args.stride_offset,
         slice=args.slice,
+        batch_size=args.batch_size,

         low_memory=args.low_memory,
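Tying the pieces together, a hypothetical caller would build jobs as (outpath, waveform, sample_rate, text, language) tuples (the shape both process_jobs and process_batched_jobs unpack above) and pass a batch_size greater than 1 to take the batched path; the new --batch-size flag defaults to 0, which keeps the original one-at-a-time loop. Loading with torchaudio and the .enc output extension are assumptions for the example, not something this commit prescribes:

    # Hypothetical usage sketch; only the job tuple shape comes from the diff above.
    import torchaudio

    jobs = []
    for wav_path, text in [("./voice/0001.wav", "first line"), ("./voice/0002.wav", "second line")]:
        waveform, sample_rate = torchaudio.load(wav_path)
        outpath = wav_path.replace(".wav", ".enc")  # extension actually depends on cfg.audio_backend
        jobs.append((outpath, waveform, sample_rate, text, "en"))

    # batch_size > 1 routes through process_batched_jobs; 0 or 1 keeps the old per-job loop
    process_jobs(jobs, speaker_id="demo", device="cuda", batch_size=8)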