fix vall_e.emb.transcriber

This commit is contained in:
mrq 2024-10-08 19:24:43 -05:00
parent acdce66d4e
commit 0656a762af
2 changed files with 45 additions and 14 deletions

View File

@ -87,21 +87,22 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
continue
def process(
audio_backend="encodec",
input_audio="voices",
input_metadata="metadata",
output_dataset="training",
raise_exceptions=False,
stride=0,
stride_offset=0,
slice="auto",
audio_backend="encodec",
input_audio="voices",
input_voice=None,
input_metadata="metadata",
output_dataset="training",
raise_exceptions=False,
stride=0,
stride_offset=0,
slice="auto",
low_memory=False,
low_memory=False,
device="cuda",
dtype="float16",
amp=False,
):
device="cuda",
dtype="float16",
amp=False,
):
# prepare from args
cfg.set_audio_backend(audio_backend)
audio_extension = cfg.audio_backend_extension
@ -129,6 +130,9 @@ def process(
}
dataset = []
if input_voice is not None:
only_speakers = [input_voice]
for group_name in sorted(os.listdir(f'./{input_audio}/')):
if not os.path.isdir(f'./{input_audio}/{group_name}/'):
_logger.warning(f'Is not dir:" /{input_audio}/{group_name}/')
@ -255,6 +259,7 @@ def main():
parser.add_argument("--audio-backend", type=str, default="encodec")
parser.add_argument("--input-audio", type=str, default="voices")
parser.add_argument("--input-voice", type=str, default=None)
parser.add_argument("--input-metadata", type=str, default="training/metadata")
parser.add_argument("--output-dataset", type=str, default="training/dataset")
parser.add_argument("--raise-exceptions", action="store_true")
@ -279,6 +284,7 @@ def main():
process(
audio_backend=args.audio_backend,
input_audio=args.input_audio,
input_voice=args.input_voice,
input_metadata=args.input_metadata,
output_dataset=args.output_dataset,
raise_exceptions=args.raise_exceptions,

View File

@ -23,6 +23,7 @@ def process_items( items, stride=0, stride_offset=0 ):
def transcribe(
input_audio = "voices",
input_voice = None,
output_metadata = "training/metadata",
model_name = "large-v3",
@ -30,12 +31,24 @@ def transcribe(
diarize = False,
stride = 0,
stride_offset = ,
stride_offset = 0,
batch_size = 16,
device = "cuda",
dtype = "float16",
):
# to-do: make this also prepared from args
language_map = {} # k = group, v = language
ignore_groups = [] # skip these groups
ignore_speakers = [] # skip these speakers
only_groups = [] # only process these groups
only_speakers = [] # only process these speakers
if input_voice is not None:
only_speakers = [input_voice]
#
model = whisperx.load_model(model_name, device, compute_type=dtype)
align_model, align_model_metadata, align_model_language = (None, None, None)
@ -49,10 +62,20 @@ def transcribe(
if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
continue
if group_name in ignore_groups:
continue
if only_groups and group_name not in only_groups:
continue
for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
continue
if speaker_id in ignore_speakers:
continue
if only_speakers and speaker_id not in only_speakers:
continue
outpath = Path(f'./{output_metadata}/{dataset_name}/{speaker_id}/whisper.json')
if outpath.exists():
@ -122,6 +145,7 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input-audio", type=str, default="voices")
parser.add_argument("--input-voice", type=str, default=None)
parser.add_argument("--output-metadata", type=str, default="training/metadata")
parser.add_argument("--model-name", type=str, default="large-v3")
@ -147,6 +171,7 @@ def main():
transcribe(
input_audio = args.input_audio,
input_voice = args.input_voice,
output_metadata = args.output_metadata,
model_name = args.model_name,