updated vall_e.emb.process to allow for batched processing, some typo fixes (it's painfully slow on my 7900XTX...)
This commit is contained in:
parent
79c504c278
commit
7592befc53
|
@ -817,13 +817,13 @@ class Config(BaseConfig):
|
||||||
cfg.model.resp_levels = 9
|
cfg.model.resp_levels = 9
|
||||||
elif cfg.audio_backend == "audiodec":
|
elif cfg.audio_backend == "audiodec":
|
||||||
audio_extension = ".dec"
|
audio_extension = ".dec"
|
||||||
sample_rate = 48_000
|
cfg.sample_rate = 48_000
|
||||||
cfg.model.resp_levels = 8 # ?
|
cfg.model.resp_levels = 8 # ?
|
||||||
elif cfg.audio_backend == "nemo":
|
elif cfg.audio_backend == "nemo":
|
||||||
audio_extension = ".nem"
|
audio_extension = ".nem"
|
||||||
sample_rate = 44_100
|
cfg.sample_rate = 44_100
|
||||||
cfg.model.resp_levels = 8
|
cfg.model.resp_levels = 8
|
||||||
cfg.model.audio_tokens = 1000
|
#cfg.model.audio_tokens = 1000
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown audio backend: {audio_backend}")
|
raise Exception(f"Unknown audio backend: {audio_backend}")
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from ..config import cfg
|
||||||
|
|
||||||
# need to validate if this is safe to import before modifying the config
|
# need to validate if this is safe to import before modifying the config
|
||||||
from .g2p import encode as phonemize
|
from .g2p import encode as phonemize
|
||||||
from .qnt import encode as quantize
|
from .qnt import encode as quantize, encode_batch as quantize_batch
|
||||||
|
|
||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
@ -75,10 +75,81 @@ def process_job( outpath, waveform, sample_rate, text=None, language="en", devic
|
||||||
|
|
||||||
np.save(open(outpath, "wb"), state_dict)
|
np.save(open(outpath, "wb"), state_dict)
|
||||||
|
|
||||||
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True ):
|
def process_batched_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
|
||||||
if not jobs:
|
if not jobs:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
buffer = []
|
||||||
|
batches = []
|
||||||
|
|
||||||
|
for job in jobs:
|
||||||
|
buffer.append(job)
|
||||||
|
if len(buffer) >= batch_size:
|
||||||
|
batches.append(buffer)
|
||||||
|
buffer = []
|
||||||
|
|
||||||
|
if len(buffer) >= batch_size:
|
||||||
|
batches.append(buffer)
|
||||||
|
buffer = []
|
||||||
|
|
||||||
|
for batch in tqdm(batches, desc=f'Quantizing {speaker_id} (batch size: {batch_size})'):
|
||||||
|
wavs = []
|
||||||
|
srs = []
|
||||||
|
|
||||||
|
for outpath, waveform, sample_rate, text, language in batch:
|
||||||
|
wavs.append(waveform)
|
||||||
|
srs.append(sample_rate)
|
||||||
|
|
||||||
|
try:
|
||||||
|
codes = quantize_batch(wavs, sr=srs, device=device)
|
||||||
|
except Exception as e:
|
||||||
|
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
|
||||||
|
if raise_exceptions:
|
||||||
|
raise e
|
||||||
|
continue
|
||||||
|
|
||||||
|
for (outpath, waveform, sample_rate, text, language), qnt in zip( batch, codes ):
|
||||||
|
if cfg.audio_backend == "dac":
|
||||||
|
state_dict = {
|
||||||
|
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||||
|
"metadata": {
|
||||||
|
"original_length": qnt.original_length,
|
||||||
|
"sample_rate": qnt.sample_rate,
|
||||||
|
|
||||||
|
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||||
|
"chunk_length": qnt.chunk_length,
|
||||||
|
"channels": qnt.channels,
|
||||||
|
"padding": qnt.padding,
|
||||||
|
"dac_version": "1.0.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
state_dict = {
|
||||||
|
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||||
|
"metadata": {
|
||||||
|
"original_length": waveform.shape[-1],
|
||||||
|
"sample_rate": sample_rate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if text:
|
||||||
|
text = text.strip()
|
||||||
|
state_dict['metadata'] |= {
|
||||||
|
"text": text,
|
||||||
|
"phonemes": phonemize(text, language=language),
|
||||||
|
"language": language,
|
||||||
|
}
|
||||||
|
|
||||||
|
np.save(open(outpath, "wb"), state_dict)
|
||||||
|
|
||||||
|
def process_jobs( jobs, speaker_id="", device=None, raise_exceptions=True, batch_size=1 ):
|
||||||
|
if not jobs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# batch things
|
||||||
|
if batch_size > 1:
|
||||||
|
return process_batched_jobs( jobs, speaker_id=speaker_id, device=device, raise_exceptions=raise_exceptions, batch_size=batch_size )
|
||||||
|
|
||||||
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
for job in tqdm(jobs, desc=f"Quantizing: {speaker_id}"):
|
||||||
outpath, waveform, sample_rate, text, language = job
|
outpath, waveform, sample_rate, text, language = job
|
||||||
try:
|
try:
|
||||||
|
@ -99,6 +170,7 @@ def process(
|
||||||
stride=0,
|
stride=0,
|
||||||
stride_offset=0,
|
stride_offset=0,
|
||||||
slice="auto",
|
slice="auto",
|
||||||
|
batch_size=1,
|
||||||
|
|
||||||
low_memory=False,
|
low_memory=False,
|
||||||
|
|
||||||
|
@ -262,12 +334,12 @@ def process(
|
||||||
|
|
||||||
# processes audio files one at a time
|
# processes audio files one at a time
|
||||||
if low_memory:
|
if low_memory:
|
||||||
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions )
|
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size )
|
||||||
jobs = []
|
jobs = []
|
||||||
|
|
||||||
# processes all audio files for a given speaker
|
# processes all audio files for a given speaker
|
||||||
if not low_memory:
|
if not low_memory:
|
||||||
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions )
|
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size )
|
||||||
jobs = []
|
jobs = []
|
||||||
|
|
||||||
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||||
|
@ -286,6 +358,7 @@ def main():
|
||||||
parser.add_argument("--stride", type=int, default=0)
|
parser.add_argument("--stride", type=int, default=0)
|
||||||
parser.add_argument("--stride-offset", type=int, default=0)
|
parser.add_argument("--stride-offset", type=int, default=0)
|
||||||
parser.add_argument("--slice", type=str, default="auto")
|
parser.add_argument("--slice", type=str, default="auto")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=0)
|
||||||
|
|
||||||
parser.add_argument("--device", type=str, default="cuda")
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||||
|
@ -315,6 +388,7 @@ def main():
|
||||||
stride=args.stride,
|
stride=args.stride,
|
||||||
stride_offset=args.stride_offset,
|
stride_offset=args.stride_offset,
|
||||||
slice=args.slice,
|
slice=args.slice,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
|
||||||
low_memory=args.low_memory,
|
low_memory=args.low_memory,
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user