fix vall_e.emb.transcriber

This commit is contained in:
mrq 2024-10-08 19:24:43 -05:00
parent acdce66d4e
commit 0656a762af
2 changed files with 45 additions and 14 deletions

View File

@ -87,21 +87,22 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
continue continue
def process( def process(
audio_backend="encodec", audio_backend="encodec",
input_audio="voices", input_audio="voices",
input_metadata="metadata", input_voice=None,
output_dataset="training", input_metadata="metadata",
raise_exceptions=False, output_dataset="training",
stride=0, raise_exceptions=False,
stride_offset=0, stride=0,
slice="auto", stride_offset=0,
slice="auto",
low_memory=False, low_memory=False,
device="cuda", device="cuda",
dtype="float16", dtype="float16",
amp=False, amp=False,
): ):
# prepare from args # prepare from args
cfg.set_audio_backend(audio_backend) cfg.set_audio_backend(audio_backend)
audio_extension = cfg.audio_backend_extension audio_extension = cfg.audio_backend_extension
@ -129,6 +130,9 @@ def process(
} }
dataset = [] dataset = []
if input_voice is not None:
only_speakers = [input_voice]
for group_name in sorted(os.listdir(f'./{input_audio}/')): for group_name in sorted(os.listdir(f'./{input_audio}/')):
if not os.path.isdir(f'./{input_audio}/{group_name}/'): if not os.path.isdir(f'./{input_audio}/{group_name}/'):
_logger.warning(f'Is not dir:" /{input_audio}/{group_name}/') _logger.warning(f'Is not dir:" /{input_audio}/{group_name}/')
@ -255,6 +259,7 @@ def main():
parser.add_argument("--audio-backend", type=str, default="encodec") parser.add_argument("--audio-backend", type=str, default="encodec")
parser.add_argument("--input-audio", type=str, default="voices") parser.add_argument("--input-audio", type=str, default="voices")
parser.add_argument("--input-voice", type=str, default=None)
parser.add_argument("--input-metadata", type=str, default="training/metadata") parser.add_argument("--input-metadata", type=str, default="training/metadata")
parser.add_argument("--output-dataset", type=str, default="training/dataset") parser.add_argument("--output-dataset", type=str, default="training/dataset")
parser.add_argument("--raise-exceptions", action="store_true") parser.add_argument("--raise-exceptions", action="store_true")
@ -279,6 +284,7 @@ def main():
process( process(
audio_backend=args.audio_backend, audio_backend=args.audio_backend,
input_audio=args.input_audio, input_audio=args.input_audio,
input_voice=args.input_voice,
input_metadata=args.input_metadata, input_metadata=args.input_metadata,
output_dataset=args.output_dataset, output_dataset=args.output_dataset,
raise_exceptions=args.raise_exceptions, raise_exceptions=args.raise_exceptions,

View File

@ -23,6 +23,7 @@ def process_items( items, stride=0, stride_offset=0 ):
def transcribe( def transcribe(
input_audio = "voices", input_audio = "voices",
input_voice = None,
output_metadata = "training/metadata", output_metadata = "training/metadata",
model_name = "large-v3", model_name = "large-v3",
@ -30,12 +31,24 @@ def transcribe(
diarize = False, diarize = False,
stride = 0, stride = 0,
stride_offset = , stride_offset = 0,
batch_size = 16, batch_size = 16,
device = "cuda", device = "cuda",
dtype = "float16", dtype = "float16",
): ):
# to-do: make this also prepared from args
language_map = {} # k = group, v = language
ignore_groups = [] # skip these groups
ignore_speakers = [] # skip these speakers
only_groups = [] # only process these groups
only_speakers = [] # only process these speakers
if input_voice is not None:
only_speakers = [input_voice]
# #
model = whisperx.load_model(model_name, device, compute_type=dtype) model = whisperx.load_model(model_name, device, compute_type=dtype)
align_model, align_model_metadata, align_model_language = (None, None, None) align_model, align_model_metadata, align_model_language = (None, None, None)
@ -49,10 +62,20 @@ def transcribe(
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
continue continue
if group_name in ignore_groups:
continue
if only_groups and group_name not in only_groups:
continue
for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"): for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'): if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
continue continue
if speaker_id in ignore_speakers:
continue
if only_speakers and speaker_id not in only_speakers:
continue
outpath = Path(f'./{output_metadata}/{dataset_name}/{speaker_id}/whisper.json') outpath = Path(f'./{output_metadata}/{dataset_name}/{speaker_id}/whisper.json')
if outpath.exists(): if outpath.exists():
@ -122,6 +145,7 @@ def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--input-audio", type=str, default="voices") parser.add_argument("--input-audio", type=str, default="voices")
parser.add_argument("--input-voice", type=str, default=None)
parser.add_argument("--output-metadata", type=str, default="training/metadata") parser.add_argument("--output-metadata", type=str, default="training/metadata")
parser.add_argument("--model-name", type=str, default="large-v3") parser.add_argument("--model-name", type=str, default="large-v3")
@ -147,6 +171,7 @@ def main():
transcribe( transcribe(
input_audio = args.input_audio, input_audio = args.input_audio,
input_voice = args.input_voice,
output_metadata = args.output_metadata, output_metadata = args.output_metadata,
model_name = args.model_name, model_name = args.model_name,