Added an option to buffer processing jobs across multiple speakers before dispatching them as a batch, which may improve throughput for vall_e.emb.process when a dataset has many speakers with only a few files each (such as Emilia).
This commit is contained in:
parent
ce1ca0124a
commit
fc1ec2019d
|
@ -170,10 +170,10 @@ def process(
|
||||||
slice="auto",
|
slice="auto",
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
max_duration=None,
|
max_duration=None,
|
||||||
max_samples=None,
|
|
||||||
min_utterances=None,
|
min_utterances=None,
|
||||||
skip_existing_folders=False,
|
skip_existing_folders=False,
|
||||||
low_memory=False,
|
low_memory=False,
|
||||||
|
batch_threshold=0,
|
||||||
strict_languages=False,
|
strict_languages=False,
|
||||||
|
|
||||||
device="cuda",
|
device="cuda",
|
||||||
|
@ -209,6 +209,15 @@ def process(
|
||||||
"audio": []
|
"audio": []
|
||||||
}
|
}
|
||||||
dataset = []
|
dataset = []
|
||||||
|
jobs = []
|
||||||
|
waveforms = {}
|
||||||
|
|
||||||
|
def check_and_process_jobs(jobs, speaker_id=""):
|
||||||
|
if len(jobs) < batch_threshold:
|
||||||
|
return False
|
||||||
|
|
||||||
|
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
||||||
|
return True
|
||||||
|
|
||||||
if input_voice is not None:
|
if input_voice is not None:
|
||||||
only_speakers = [input_voice]
|
only_speakers = [input_voice]
|
||||||
|
@ -276,13 +285,11 @@ def process(
|
||||||
dataset.append(f'{group_name}/{speaker_id}')
|
dataset.append(f'{group_name}/{speaker_id}')
|
||||||
|
|
||||||
|
|
||||||
jobs = []
|
|
||||||
use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
|
use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups
|
||||||
if min_utterances and len(metadata.keys()) < min_utterances:
|
if min_utterances and len(metadata.keys()) < min_utterances:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for filename in sorted(metadata.keys()):
|
for filename in sorted(metadata.keys()):
|
||||||
|
|
||||||
inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
|
inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -371,19 +378,8 @@ def process(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
jobs.append(( outpath, waveform if presliced else waveform[:, start:end], sample_rate, text, language ))
|
jobs.append(( outpath, waveform if presliced else waveform[:, start:end], sample_rate, text, language ))
|
||||||
if max_samples and len(jobs) >= max_samples:
|
|
||||||
break
|
|
||||||
if not low_memory and max_samples and len(jobs) >= max_samples:
|
|
||||||
break
|
|
||||||
|
|
||||||
# processes audio files one at a time
|
if check_and_process_jobs(jobs, speaker_id=speaker_id):
|
||||||
if low_memory:
|
|
||||||
process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
|
||||||
jobs = []
|
|
||||||
|
|
||||||
# processes all audio files for a given speaker
|
|
||||||
if not low_memory:
|
|
||||||
process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
|
|
||||||
jobs = []
|
jobs = []
|
||||||
|
|
||||||
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||||
|
@ -399,7 +395,7 @@ def main():
|
||||||
parser.add_argument("--output-dataset", type=str, default="training/dataset")
|
parser.add_argument("--output-dataset", type=str, default="training/dataset")
|
||||||
parser.add_argument("--transcription-filename", type=str, default="whisper.json")
|
parser.add_argument("--transcription-filename", type=str, default="whisper.json")
|
||||||
parser.add_argument("--raise-exceptions", action="store_true")
|
parser.add_argument("--raise-exceptions", action="store_true")
|
||||||
parser.add_argument("--low-memory", action="store_true")
|
#parser.add_argument("--low-memory", action="store_true")
|
||||||
parser.add_argument("--skip-existing-folders", action="store_true")
|
parser.add_argument("--skip-existing-folders", action="store_true")
|
||||||
parser.add_argument("--strict-languages", action="store_true")
|
parser.add_argument("--strict-languages", action="store_true")
|
||||||
parser.add_argument("--stride", type=int, default=0)
|
parser.add_argument("--stride", type=int, default=0)
|
||||||
|
@ -407,8 +403,8 @@ def main():
|
||||||
parser.add_argument("--slice", type=str, default="auto")
|
parser.add_argument("--slice", type=str, default="auto")
|
||||||
parser.add_argument("--batch-size", type=int, default=0)
|
parser.add_argument("--batch-size", type=int, default=0)
|
||||||
parser.add_argument("--max-duration", type=int, default=0)
|
parser.add_argument("--max-duration", type=int, default=0)
|
||||||
parser.add_argument("--max-samples", type=int, default=0)
|
|
||||||
parser.add_argument("--min-utterances", type=int, default=0)
|
parser.add_argument("--min-utterances", type=int, default=0)
|
||||||
|
parser.add_argument("--batch-threshold", type=int, default=0)
|
||||||
|
|
||||||
parser.add_argument("--device", type=str, default="cuda")
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||||
|
@ -441,12 +437,12 @@ def main():
|
||||||
slice=args.slice,
|
slice=args.slice,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
max_duration=args.max_duration,
|
max_duration=args.max_duration,
|
||||||
max_samples=args.max_samples,
|
|
||||||
min_utterances=args.min_utterances,
|
min_utterances=args.min_utterances,
|
||||||
|
batch_threshold=args.batch_threshold,
|
||||||
skip_existing_folders=args.skip_existing_folders,
|
skip_existing_folders=args.skip_existing_folders,
|
||||||
strict_languages=args.strict_languages,
|
strict_languages=args.strict_languages,
|
||||||
|
|
||||||
low_memory=args.low_memory,
|
#low_memory=args.low_memory,
|
||||||
|
|
||||||
device=args.device,
|
device=args.device,
|
||||||
dtype=args.dtype,
|
dtype=args.dtype,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user