added an option to buffer process jobs across multiple speakers to maybe squeeze out some extra throughput for vall_e.emb.process (in the event of lots of speakers with low file counts, such as Emilia)

mrq 2025-02-20 14:56:32 -06:00
parent ce1ca0124a
commit fc1ec2019d


@@ -170,10 +170,10 @@ def process(
 	slice="auto",
 	batch_size=1,
 	max_duration=None,
-	max_samples=None,
 	min_utterances=None,
 	skip_existing_folders=False,
 	low_memory=False,
+	batch_threshold=0,
 	strict_languages=False,
 	device="cuda",
@@ -209,6 +209,15 @@ def process(
 		"audio": []
 	}
 	dataset = []
+	jobs = []
+	waveforms = {}
+
+	def check_and_process_jobs(jobs, speaker_id=""):
+		if len(jobs) < batch_threshold:
+			return False
+
+		process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
+		return True
 
 	if input_voice is not None:
 		only_speakers = [input_voice]
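For readers skimming the diff, here is a minimal, standalone sketch of the buffering pattern the new check_and_process_jobs helper implements. The helper name make_buffer_flusher and the toy job data below are illustrative only; the real helper forwards the buffered jobs to process_jobs() with the device/dtype/batch settings shown above.

# Sketch of the cross-speaker job buffering introduced in this commit.
# `do_process` stands in for the real process_jobs(); job contents are dummies.
def make_buffer_flusher(batch_threshold, do_process):
	def check_and_process(jobs, speaker_id=""):
		# keep buffering until enough jobs have accumulated across speakers
		if len(jobs) < batch_threshold:
			return False
		do_process(jobs, speaker_id=speaker_id)
		return True  # caller is expected to reset the buffer
	return check_and_process

if __name__ == "__main__":
	jobs = []
	flush = make_buffer_flusher(4, lambda jobs, speaker_id="": print(f"processed {len(jobs)} jobs (last speaker: {speaker_id})"))
	for speaker, n_files in {"spk_a": 3, "spk_b": 2, "spk_c": 1}.items():
		jobs.extend((speaker, i) for i in range(n_files))
		if flush(jobs, speaker_id=speaker):
			jobs = []
	# note: any leftover jobs below the threshold would still need a final flush

With the default --batch-threshold of 0, len(jobs) < 0 is never true, so the helper flushes on every call and the per-speaker behaviour is effectively unchanged; only a positive threshold actually buffers jobs across speakers.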
@@ -276,13 +285,11 @@ def process(
 			dataset.append(f'{group_name}/{speaker_id}')
-			jobs = []
-
 			use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
 
 			if min_utterances and len(metadata.keys()) < min_utterances:
 				continue
 
 			for filename in sorted(metadata.keys()):
 				inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
 
 				"""
@@ -371,19 +378,8 @@ def process(
 						continue
 
 					jobs.append(( outpath, waveform if presliced else waveform[:, start:end], sample_rate, text, language ))
 
-					if max_samples and len(jobs) >= max_samples:
-						break
-
-				if not low_memory and max_samples and len(jobs) >= max_samples:
-					break
-
-				# processes audio files one at a time
-				if low_memory:
-					process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
-					jobs = []
-
-			# processes all audio files for a given speaker
-			if not low_memory:
-				process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
+			if check_and_process_jobs(jobs, speaker_id=speaker_id):
 				jobs = []
 
 	open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
@@ -399,7 +395,7 @@ def main():
 	parser.add_argument("--output-dataset", type=str, default="training/dataset")
 	parser.add_argument("--transcription-filename", type=str, default="whisper.json")
 	parser.add_argument("--raise-exceptions", action="store_true")
-	parser.add_argument("--low-memory", action="store_true")
+	#parser.add_argument("--low-memory", action="store_true")
 	parser.add_argument("--skip-existing-folders", action="store_true")
 	parser.add_argument("--strict-languages", action="store_true")
 	parser.add_argument("--stride", type=int, default=0)
@@ -407,8 +403,8 @@ def main():
 	parser.add_argument("--slice", type=str, default="auto")
 	parser.add_argument("--batch-size", type=int, default=0)
 	parser.add_argument("--max-duration", type=int, default=0)
-	parser.add_argument("--max-samples", type=int, default=0)
 	parser.add_argument("--min-utterances", type=int, default=0)
+	parser.add_argument("--batch-threshold", type=int, default=0)
 	parser.add_argument("--device", type=str, default="cuda")
 	parser.add_argument("--dtype", type=str, default="bfloat16")
@@ -441,12 +437,12 @@ def main():
 		slice=args.slice,
 		batch_size=args.batch_size,
 		max_duration=args.max_duration,
-		max_samples=args.max_samples,
 		min_utterances=args.min_utterances,
+		batch_threshold=args.batch_threshold,
 		skip_existing_folders=args.skip_existing_folders,
 		strict_languages=args.strict_languages,
-		low_memory=args.low_memory,
+		#low_memory=args.low_memory,
 		device=args.device,
 		dtype=args.dtype,
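For reference, a hypothetical invocation using the new flag might look like the following; the module-style call and the threshold of 64 are assumptions for illustration (only the flag names come from the argparse changes above):

python3 -m vall_e.emb.process --batch-threshold 64 --batch-size 16 --device cuda --dtype bfloat16

The threshold counts buffered jobs (utterances), not speakers, so many small speakers like those in Emilia can accumulate into a single process_jobs() call instead of each triggering its own.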