From 0656a762af77749b9b6a021c26cb78c3e12f0cf4 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 8 Oct 2024 19:24:43 -0500 Subject: [PATCH] fix vall_e.emb.transcriber --- vall_e/emb/process.py | 32 +++++++++++++++++++------------- vall_e/emb/transcribe.py | 27 ++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 82723b3..57ecdfb 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -87,21 +87,22 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ): continue def process( - audio_backend="encodec", - input_audio="voices", - input_metadata="metadata", - output_dataset="training", - raise_exceptions=False, - stride=0, - stride_offset=0, - slice="auto", + audio_backend="encodec", + input_audio="voices", + input_voice=None, + input_metadata="metadata", + output_dataset="training", + raise_exceptions=False, + stride=0, + stride_offset=0, + slice="auto", - low_memory=False, + low_memory=False, - device="cuda", - dtype="float16", - amp=False, - ): + device="cuda", + dtype="float16", + amp=False, +): # prepare from args cfg.set_audio_backend(audio_backend) audio_extension = cfg.audio_backend_extension @@ -129,6 +130,9 @@ def process( } dataset = [] + if input_voice is not None: + only_speakers = [input_voice] + for group_name in sorted(os.listdir(f'./{input_audio}/')): if not os.path.isdir(f'./{input_audio}/{group_name}/'): _logger.warning(f'Is not dir:" /{input_audio}/{group_name}/') @@ -255,6 +259,7 @@ def main(): parser.add_argument("--audio-backend", type=str, default="encodec") parser.add_argument("--input-audio", type=str, default="voices") + parser.add_argument("--input-voice", type=str, default=None) parser.add_argument("--input-metadata", type=str, default="training/metadata") parser.add_argument("--output-dataset", type=str, default="training/dataset") parser.add_argument("--raise-exceptions", action="store_true") @@ -279,6 +284,7 @@ def main(): process( audio_backend=args.audio_backend, input_audio=args.input_audio, + input_voice=args.input_voice, input_metadata=args.input_metadata, output_dataset=args.output_dataset, raise_exceptions=args.raise_exceptions, diff --git a/vall_e/emb/transcribe.py b/vall_e/emb/transcribe.py index 6b05281..6d66c86 100644 --- a/vall_e/emb/transcribe.py +++ b/vall_e/emb/transcribe.py @@ -23,6 +23,7 @@ def process_items( items, stride=0, stride_offset=0 ): def transcribe( input_audio = "voices", + input_voice = None, output_metadata = "training/metadata", model_name = "large-v3", @@ -30,12 +31,24 @@ def transcribe( diarize = False, stride = 0, - stride_offset = , + stride_offset = 0, batch_size = 16, device = "cuda", dtype = "float16", ): + # to-do: make this also prepared from args + language_map = {} # k = group, v = language + + ignore_groups = [] # skip these groups + ignore_speakers = [] # skip these speakers + + only_groups = [] # only process these groups + only_speakers = [] # only process these speakers + + if input_voice is not None: + only_speakers = [input_voice] + # model = whisperx.load_model(model_name, device, compute_type=dtype) align_model, align_model_metadata, align_model_language = (None, None, None) @@ -49,10 +62,20 @@ def transcribe( if not os.path.isdir(f'./{input_audio}/{dataset_name}/'): continue + if group_name in ignore_groups: + continue + if only_groups and group_name not in only_groups: + continue + for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"): if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'): continue + if speaker_id in ignore_speakers: + continue + if only_speakers and speaker_id not in only_speakers: + continue + outpath = Path(f'./{output_metadata}/{dataset_name}/{speaker_id}/whisper.json') if outpath.exists(): @@ -122,6 +145,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--input-audio", type=str, default="voices") + parser.add_argument("--input-voice", type=str, default=None) parser.add_argument("--output-metadata", type=str, default="training/metadata") parser.add_argument("--model-name", type=str, default="large-v3") @@ -147,6 +171,7 @@ def main(): transcribe( input_audio = args.input_audio, + input_voice = args.input_voice, output_metadata = args.output_metadata, model_name = args.model_name,