From 382a3e41048f0c371d525ed19464165a8cd5a63a Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 11 Mar 2023 21:17:11 +0000
Subject: [PATCH] rely on the whisper.json for handling a lot more things

---
 src/utils.py | 194 +++++++++++++++++++++------------------------
 src/webui.py |  26 ++++---
 2 files changed, 94 insertions(+), 126 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index d549145..8fe8fc6 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -33,7 +33,7 @@ from datetime import datetime
 from datetime import timedelta
 
 from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
-from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
+from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram
 
@@ -1059,6 +1059,47 @@ def validate_waveform( waveform, sample_rate ):
 		return False
 
 	return True
 
+def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
+	unload_tts()
+
+	global whisper_model
+	if whisper_model is None:
+		load_whisper_model(language=language)
+
+
+	results = {}
+
+	files = sorted( get_voices(load_latents=False)[voice] )
+	indir = f'./training/{voice}/'
+	infile = f'{indir}/whisper.json'
+
+	os.makedirs(f'{indir}/audio/', exist_ok=True)
+
+	if os.path.exists(infile):
+		results = json.load(open(infile, 'r', encoding="utf-8"))
+
+	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
+		basename = os.path.basename(file)
+
+		if basename in results and skip_existings:
+			print(f"Skipping already parsed file: {basename}")
+			continue
+
+		results[basename] = whisper_transcribe(file, language=language)
+
+		# lazy copy
+		waveform, sampling_rate = torchaudio.load(file)
+		torchaudio.save(f"{indir}/audio/{basename}", waveform, sampling_rate)
+
+	with open(infile, 'w', encoding="utf-8") as f:
+		f.write(json.dumps(results, indent='\t'))
+
+	do_gc()
+
+	unload_whisper()
+
+	return f"Processed dataset to: {indir}"
+
 def slice_dataset( voice, start_offset=0, end_offset=0 ):
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
@@ -1066,148 +1107,71 @@ def slice_dataset( voice, start_offset=0, end_offset=0 ):
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
 
-	with open(infile, 'r', encoding="utf-8") as f:
-		results = json.load(f)
+	results = json.load(open(infile, 'r', encoding="utf-8"))
 
-	transcription = []
+	files = 0
+	segments = 0
 
 	for filename in results:
-		idx = 0
+		files += 1
+		result = results[filename]
 		waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
 
 		for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+			segments +=1
 			start = int((segment['start'] + start_offset) * sampling_rate)
 			end = int((segment['end'] + end_offset) * sampling_rate)
 
-			sliced_waveform = waveform[:, start:end]
-			sliced_name = filename.replace(".wav", f"_{pad(idx, 4)}.wav")
+			sliced = waveform[:, start:end]
+			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
 
-			if not validate_waveform( sliced_waveform, sampling_rate ):
-				print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
+			if not validate_waveform( sliced, sampling_rate ):
+				print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {file}, skipping...")
 				continue
 
-			torchaudio.save(f"{indir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
+			torchaudio.save(f"{indir}/audio/{file}", sliced, sampling_rate)
 
-			idx = idx + 1
-			line = f"audio/{sliced_name}|{segment['text'].strip()}"
-			transcription.append(line)
-			with open(f'{indir}/train.txt', 'a', encoding="utf-8") as f:
-				f.write(f'\n{line}')
+	return f"Sliced segments: {files} => {segments}."
 
-	joined = "\n".join(transcription)
-	with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write(joined)
-
-	return f"Processed dataset to: {indir}\n{joined}"
-
-def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
-	unload_tts()
-
-	global whisper_model
-	if whisper_model is None:
-		load_whisper_model(language=language)
-
-	os.makedirs(f'{outdir}/audio/', exist_ok=True)
-
-	results = {}
-	transcription = []
-	files = sorted(files)
-
-	previous_list = []
-	if skip_existings and os.path.exists(f'{outdir}/train.txt'):
-		parsed_list = []
-		with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
-			parsed_list = f.readlines()
-
-		for line in parsed_list:
-			match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
-
-			if match is None or len(match) == 0:
-				continue
-
-			if match[0] not in previous_list:
-				previous_list.append(f'{match[0].split("/")[-1]}.wav')
-
-
-	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
-		basename = os.path.basename(file)
-
-		if basename in previous_list:
-			print(f"Skipping already parsed file: {basename}")
-			continue
-
-		result = whisper_transcribe(file, language=language)
-		results[basename] = result
-		print(f"Transcribed file: {file}, {len(result['segments'])} found.")
-
-		waveform, sampling_rate = torchaudio.load(file)
-
-		if not validate_waveform( waveform, sampling_rate ):
-			print(f"Invalid waveform: {basename}, skipping...")
-			continue
-
-		torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
-		line = f"audio/{basename}|{result['text'].strip()}"
-		transcription.append(line)
-		with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-			f.write(f'\n{line}')
-
-		do_gc()
-
-	with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
-		f.write(json.dumps(results, indent='\t'))
-
-	unload_whisper()
-
-	joined = "\n".join(transcription)
-	if not skip_existings:
-		with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-			f.write(joined)
-
-	return f"Processed dataset to: {outdir}\n{joined}"
-
-def prepare_validation_dataset( voice, text_length, audio_length ):
+def prepare_dataset( voice, use_segments, text_length, audio_length ):
 	indir = f'./training/{voice}/'
-	infile = f'{indir}/dataset.txt'
-	if not os.path.exists(infile):
-		infile = f'{indir}/train.txt'
-		with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
-			with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
-				dst.write(src.read())
+	infile = f'{indir}/whisper.json'
 
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
 
-	with open(infile, 'r', encoding="utf-8") as f:
-		lines = f.readlines()
+	results = json.load(open(infile, 'r', encoding="utf-8"))
 
-	validation = []
-	training = []
+	lines = {
+		'training': [],
+		'validation': [],
+	}
 
-	for line in lines:
-		split = line.split("|")
-		filename = split[0]
-		text = split[1]
-		culled = len(text) < text_length
+	for filename in results:
+		result = results[filename]
+		segments = result['segments'] if use_segments else [{'text': result['text']}]
+		for segment in segments:
+			text = segment['text'].strip()
+			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
 
-		if not culled and audio_length > 0:
-			metadata = torchaudio.info(f'{indir}/{filename}')
-			duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
-			culled = duration < audio_length
+			culled = len(text) < text_length
+			if not culled and audio_length > 0:
+				metadata = torchaudio.info(f'{indir}/audio/{file}')
+				duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
+				culled = duration < audio_length
 
-		if culled:
-			validation.append(line.strip())
-		else:
-			training.append(line.strip())
+			lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
+
+	training_joined = "\n".join(lines['training'])
+	validation_joined = "\n".join(lines['validation'])
 
 	with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
-		f.write("\n".join(training))
+		f.write(training_joined)
 
 	with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
-		f.write("\n".join(validation))
+		f.write(validation_joined)
 
-	msg = f"Culled {len(validation)}/{len(lines)} lines."
-	print(msg)
+	msg = f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}"
 	return msg
 
 def calc_iterations( epochs, lines, batch_size ):
diff --git a/src/webui.py b/src/webui.py
index c4e0190..82bec2f 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -182,16 +182,19 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
 		gr.update(visible=j is not None),
 	)
 
-def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
+def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=False) ):
 	messages = []
-	message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
+
+	message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
 	messages.append(message)
+
 	if slice_audio:
 		message = slice_dataset( voice )
 		messages.append(message)
-	if validation_text_length > 0 or validation_audio_length > 0:
-		message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
-		messages.append(message)
+
+	message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length )
+	messages.append(message)
+
 	return "\n".join(messages)
 
 def update_args_proxy( *args ):
@@ -421,8 +424,8 @@ def setup_gradio():
 
 				with gr.Row():
 					transcribe_button = gr.Button(value="Transcribe")
-					prepare_validation_button = gr.Button(value="(Re)Create Validation Dataset")
 					slice_dataset_button = gr.Button(value="(Re)Slice Audio")
+					prepare_dataset_button = gr.Button(value="(Re)Create Dataset")
 
 			with gr.Row():
 				EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
@@ -654,7 +657,7 @@ def setup_gradio():
 			inputs=None,
 			outputs=[
 				GENERATE_SETTINGS['voice'],
-				dataset_settings[0],
+				DATASET_SETTINGS['voice'],
 				history_voices
 			]
 		)
@@ -742,10 +745,11 @@ def setup_gradio():
 			inputs=dataset_settings,
 			outputs=prepare_dataset_output #console_output
 		)
-		prepare_validation_button.click(
-			prepare_validation_dataset,
+		prepare_dataset_button.click(
+			prepare_dataset,
 			inputs=[
-				dataset_settings[0],
+				DATASET_SETTINGS['voice'],
+				DATASET_SETTINGS['slice'],
 				DATASET_SETTINGS['validation_text_length'],
 				DATASET_SETTINGS['validation_audio_length'],
 			],
@@ -754,7 +758,7 @@ def setup_gradio():
 		slice_dataset_button.click(
 			slice_dataset,
 			inputs=[
-				dataset_settings[0]
+				DATASET_SETTINGS['voice']
 			],
 			outputs=prepare_dataset_output
 		)