added option to buffer process jobs across multiple speakers to maybe squeeze out some throughput speeds for vall_e.emb.process (in the event of lots of speakers with low file counts, such as Emilia)
This commit is contained in:
parent
ce1ca0124a
commit
fc1ec2019d
|
@ -170,10 +170,10 @@ def process(
|
|||
slice="auto",
|
||||
batch_size=1,
|
||||
max_duration=None,
|
||||
max_samples=None,
|
||||
min_utterances=None,
|
||||
skip_existing_folders=False,
|
||||
low_memory=False,
|
||||
batch_threshold=0,
|
||||
strict_languages=False,
|
||||
|
||||
device="cuda",
|
||||
|
@ -209,6 +209,15 @@ def process(
|
|||
"audio": []
|
||||
}
|
||||
dataset = []
|
||||
jobs = []
|
||||
waveforms = {}
|
||||
|
||||
def check_and_process_jobs(jobs, speaker_id=""):
|
||||
if len(jobs) < batch_threshold:
|
||||
return False
|
||||
|
||||
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
||||
return True
|
||||
|
||||
if input_voice is not None:
|
||||
only_speakers = [input_voice]
|
||||
|
@ -276,13 +285,11 @@ def process(
|
|||
dataset.append(f'{group_name}/{speaker_id}')
|
||||
|
||||
|
||||
jobs = []
|
||||
use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
|
||||
if min_utterances and len(metadata.keys()) < min_utterances:
|
||||
continue
|
||||
|
||||
for filename in sorted(metadata.keys()):
|
||||
|
||||
inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
|
||||
|
||||
"""
|
||||
|
@ -371,19 +378,8 @@ def process(
|
|||
continue
|
||||
|
||||
jobs.append(( outpath, waveform if presliced else waveform[:, start:end], sample_rate, text, language ))
|
||||
if max_samples and len(jobs) >= max_samples:
|
||||
break
|
||||
if not low_memory and max_samples and len(jobs) >= max_samples:
|
||||
break
|
||||
|
||||
# processes audio files one at a time
|
||||
if low_memory:
|
||||
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
||||
jobs = []
|
||||
|
||||
# processes all audio files for a given speaker
|
||||
if not low_memory:
|
||||
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
||||
if check_and_process_jobs(jobs, speaker_id=speaker_id):
|
||||
jobs = []
|
||||
|
||||
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||
|
@ -399,7 +395,7 @@ def main():
|
|||
parser.add_argument("--output-dataset", type=str, default="training/dataset")
|
||||
parser.add_argument("--transcription-filename", type=str, default="whisper.json")
|
||||
parser.add_argument("--raise-exceptions", action="store_true")
|
||||
parser.add_argument("--low-memory", action="store_true")
|
||||
#parser.add_argument("--low-memory", action="store_true")
|
||||
parser.add_argument("--skip-existing-folders", action="store_true")
|
||||
parser.add_argument("--strict-languages", action="store_true")
|
||||
parser.add_argument("--stride", type=int, default=0)
|
||||
|
@ -407,8 +403,8 @@ def main():
|
|||
parser.add_argument("--slice", type=str, default="auto")
|
||||
parser.add_argument("--batch-size", type=int, default=0)
|
||||
parser.add_argument("--max-duration", type=int, default=0)
|
||||
parser.add_argument("--max-samples", type=int, default=0)
|
||||
parser.add_argument("--min-utterances", type=int, default=0)
|
||||
parser.add_argument("--batch-threshold", type=int, default=0)
|
||||
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
|
@ -441,12 +437,12 @@ def main():
|
|||
slice=args.slice,
|
||||
batch_size=args.batch_size,
|
||||
max_duration=args.max_duration,
|
||||
max_samples=args.max_samples,
|
||||
min_utterances=args.min_utterances,
|
||||
batch_threshold=args.batch_threshold,
|
||||
skip_existing_folders=args.skip_existing_folders,
|
||||
strict_languages=args.strict_languages,
|
||||
|
||||
low_memory=args.low_memory,
|
||||
#low_memory=args.low_memory,
|
||||
|
||||
device=args.device,
|
||||
dtype=args.dtype,
|
||||
|
|
Loading…
Reference in New Issue
Block a user