fix vall_e.emb.transcriber
parent acdce66d4e
commit 0656a762af
@@ -87,21 +87,22 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
 			continue
 
 def process(
 	audio_backend="encodec",
 	input_audio="voices",
+	input_voice=None,
 	input_metadata="metadata",
 	output_dataset="training",
 	raise_exceptions=False,
 	stride=0,
 	stride_offset=0,
 	slice="auto",
 
 	low_memory=False,
 
 	device="cuda",
 	dtype="float16",
 	amp=False,
 ):
 	# prepare from args
 	cfg.set_audio_backend(audio_backend)
 	audio_extension = cfg.audio_backend_extension
@@ -129,6 +130,9 @@ def process(
 	}
 	dataset = []
 
+	if input_voice is not None:
+		only_speakers = [input_voice]
+
 	for group_name in sorted(os.listdir(f'./{input_audio}/')):
 		if not os.path.isdir(f'./{input_audio}/{group_name}/'):
 			_logger.warning(f'Is not dir:" /{input_audio}/{group_name}/')
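
Note: the ignore/only lists and the new input_voice override boil down to a small membership test. The sketch below is illustrative only; should_skip is a hypothetical helper and is not part of this diff, which inlines the checks directly in its loops.

# Illustrative sketch of the skip rule implemented by the additions above.
# `should_skip` is hypothetical; the real code inlines these checks.
def should_skip(name, ignore_list, only_list):
	if name in ignore_list:
		return True
	if only_list and name not in only_list:
		return True
	return False

only_speakers = ["some_voice"]	# e.g. collapsed from input_voice
assert should_skip("other_voice", [], only_speakers)
assert not should_skip("some_voice", [], only_speakers)
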
@@ -255,6 +259,7 @@ def main():
 
 	parser.add_argument("--audio-backend", type=str, default="encodec")
 	parser.add_argument("--input-audio", type=str, default="voices")
+	parser.add_argument("--input-voice", type=str, default=None)
 	parser.add_argument("--input-metadata", type=str, default="training/metadata")
 	parser.add_argument("--output-dataset", type=str, default="training/dataset")
 	parser.add_argument("--raise-exceptions", action="store_true")
@@ -279,6 +284,7 @@ def main():
 	process(
 		audio_backend=args.audio_backend,
 		input_audio=args.input_audio,
+		input_voice=args.input_voice,
 		input_metadata=args.input_metadata,
 		output_dataset=args.output_dataset,
 		raise_exceptions=args.raise_exceptions,
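
Note: the remaining hunks apply the same change to the transcription entry point (process_items / transcribe / main). The plumbing is identical in both places: an optional --input-voice argument that defaults to None and is forwarded as a keyword. A reduced, self-contained sketch of that pattern, where run stands in for process()/transcribe():

# Reduced sketch of the --input-voice plumbing; `run` is a stand-in name.
import argparse

def run(input_voice=None):
	only_speakers = [input_voice] if input_voice is not None else []
	print("restricting to:", only_speakers or "all speakers")

parser = argparse.ArgumentParser()
parser.add_argument("--input-voice", type=str, default=None)
args = parser.parse_args(["--input-voice", "some_voice"])	# example argv
run(input_voice=args.input_voice)
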
@@ -23,6 +23,7 @@ def process_items( items, stride=0, stride_offset=0 ):
 
 def transcribe(
 	input_audio = "voices",
+	input_voice = None,
 	output_metadata = "training/metadata",
 	model_name = "large-v3",
 
@@ -30,12 +31,24 @@ def transcribe(
 	diarize = False,
 
 	stride = 0,
-	stride_offset = ,
+	stride_offset = 0,
 
 	batch_size = 16,
 	device = "cuda",
 	dtype = "float16",
 ):
+	# to-do: make this also prepared from args
+	language_map = {} # k = group, v = language
+
+	ignore_groups = [] # skip these groups
+	ignore_speakers = [] # skip these speakers
+
+	only_groups = [] # only process these groups
+	only_speakers = [] # only process these speakers
+
+	if input_voice is not None:
+		only_speakers = [input_voice]
+
 	#
 	model = whisperx.load_model(model_name, device, compute_type=dtype)
 	align_model, align_model_metadata, align_model_language = (None, None, None)
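
Note: the one-character change above (stride_offset = , to stride_offset = 0,) repairs a SyntaxError in the keyword defaults, which appears to be the fix named in the commit title. The body of process_items() is not part of this diff; the following is only a guess at how a stride/offset pair is commonly used to shard a speaker list across parallel runs.

# Hypothetical illustration of stride/stride_offset sharding; `shard` is not
# from the diff and process_items() may well differ.
def shard(items, stride=0, stride_offset=0):
	if stride <= 1:	# 0 (or 1) disables sharding
		return list(items)
	return [x for i, x in enumerate(items) if i % stride == stride_offset]

speakers = ["s0", "s1", "s2", "s3", "s4"]
print(shard(speakers, stride=2, stride_offset=0))	# ['s0', 's2', 's4']
print(shard(speakers, stride=2, stride_offset=1))	# ['s1', 's3']
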
@@ -49,10 +62,20 @@ def transcribe(
 		if not os.path.isdir(f'./{input_audio}/{dataset_name}/'):
 			continue
 
+		if group_name in ignore_groups:
+			continue
+		if only_groups and group_name not in only_groups:
+			continue
+
 		for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{dataset_name}/')), desc="Processing speaker"):
 			if not os.path.isdir(f'./{input_audio}/{dataset_name}/{speaker_id}'):
 				continue
 
+			if speaker_id in ignore_speakers:
+				continue
+			if only_speakers and speaker_id not in only_speakers:
+				continue
+
 			outpath = Path(f'./{output_metadata}/{dataset_name}/{speaker_id}/whisper.json')
 
 			if outpath.exists():
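
Note: the trailing context above writes per-speaker output to .../whisper.json and checks for an existing file; what happens past `if outpath.exists():` is cut off by the hunk. A minimal cache-or-compute sketch of that general pattern, with load_or_create as a hypothetical stand-in:

# Hypothetical resume pattern around an existing whisper.json; the real branch
# body is not visible in this hunk.
import json
from pathlib import Path

def load_or_create(outpath: Path, compute):
	if outpath.exists():	# reuse previously written results
		return json.loads(outpath.read_text())
	result = compute()
	outpath.parent.mkdir(parents=True, exist_ok=True)
	outpath.write_text(json.dumps(result))
	return result
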
@@ -122,6 +145,7 @@ def main():
 	parser = argparse.ArgumentParser()
 
 	parser.add_argument("--input-audio", type=str, default="voices")
+	parser.add_argument("--input-voice", type=str, default=None)
 	parser.add_argument("--output-metadata", type=str, default="training/metadata")
 
 	parser.add_argument("--model-name", type=str, default="large-v3")
@@ -147,6 +171,7 @@ def main():
 
 	transcribe(
 		input_audio = args.input_audio,
+		input_voice = args.input_voice,
 		output_metadata = args.output_metadata,
 		model_name = args.model_name,
 