diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 16ade46..247bd2c 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -170,10 +170,10 @@ def process( slice="auto", batch_size=1, max_duration=None, - max_samples=None, min_utterances=None, skip_existing_folders=False, low_memory=False, + batch_threshold=0, strict_languages=False, device="cuda", @@ -209,6 +209,15 @@ def process( "audio": [] } dataset = [] + jobs = [] + waveforms = {} + + def check_and_process_jobs(jobs, speaker_id=""): + if len(jobs) < batch_threshold: + return False + + process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) + return True if input_voice is not None: only_speakers = [input_voice] @@ -276,13 +285,11 @@ def process( dataset.append(f'{group_name}/{speaker_id}') - jobs = [] use_slices = slice == True or (slice == "auto" and len(metadata.keys()) == 1) or group_name in always_slice_groups if min_utterances and len(metadata.keys()) < min_utterances: continue for filename in sorted(metadata.keys()): - inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{filename}') """ @@ -371,19 +378,8 @@ def process( continue jobs.append(( outpath, waveform if presliced else waveform[:, start:end], sample_rate, text, language )) - if max_samples and len(jobs) >= max_samples: - break - if not low_memory and max_samples and len(jobs) >= max_samples: - break - # processes audio files one at a time - if low_memory: - process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) - jobs = [] - - # processes all audio files for a given speaker - if not low_memory: - process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) + if check_and_process_jobs(jobs, speaker_id=speaker_id): jobs = [] open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) @@ -399,7 +395,7 @@ def main(): parser.add_argument("--output-dataset", type=str, default="training/dataset") parser.add_argument("--transcription-filename", type=str, default="whisper.json") parser.add_argument("--raise-exceptions", action="store_true") - parser.add_argument("--low-memory", action="store_true") + #parser.add_argument("--low-memory", action="store_true") parser.add_argument("--skip-existing-folders", action="store_true") parser.add_argument("--strict-languages", action="store_true") parser.add_argument("--stride", type=int, default=0) @@ -407,8 +403,8 @@ def main(): parser.add_argument("--slice", type=str, default="auto") parser.add_argument("--batch-size", type=int, default=0) parser.add_argument("--max-duration", type=int, default=0) - parser.add_argument("--max-samples", type=int, default=0) parser.add_argument("--min-utterances", type=int, default=0) + parser.add_argument("--batch-threshold", type=int, default=0) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--dtype", type=str, default="bfloat16") @@ -441,12 +437,12 @@ def main(): slice=args.slice, batch_size=args.batch_size, max_duration=args.max_duration, - max_samples=args.max_samples, min_utterances=args.min_utterances, + batch_threshold=args.batch_threshold, skip_existing_folders=args.skip_existing_folders, strict_languages=args.strict_languages, - low_memory=args.low_memory, + #low_memory=args.low_memory, device=args.device, dtype=args.dtype,