From 94551fb9ac18ba0964606e8b9056701369c727ca Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 11 Mar 2023 17:27:01 +0000
Subject: [PATCH] split slicing dataset routine so it can be done after the fact

---
 src/utils.py | 96 +++++++++++++++++++++++++++++++---------------------
 src/webui.py | 15 ++++++--
 2 files changed, 71 insertions(+), 40 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 0e94cea..cc2d4eb 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -1051,7 +1051,56 @@ def whisper_transcribe( file, language=None ):
 			result['segments'].append(reparsed)
 	return result
 
-def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_audio=False, progress=None ):
+def validate_waveform( waveform, sample_rate ):
+	if not torch.any(waveform < 0):
+		return False
+
+	if waveform.shape[-1] < (.6 * sample_rate):
+		return False
+	return True
+
+def slice_dataset( voice, start_offset=0, end_offset=0 ):
+	indir = f'./training/{voice}/'
+	infile = f'{indir}/whisper.json'
+
+	if not os.path.exists(infile):
+		raise Exception(f"Missing dataset: {infile}")
+
+	with open(infile, 'r', encoding="utf-8") as f:
+		results = json.load(f)
+
+	transcription = []
+	for filename in results:
+		idx = 0
+		result = results[filename]
+		waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
+
+		for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+			start = int((segment['start'] + start_offset) * sampling_rate)
+			end = int((segment['end'] + end_offset) * sampling_rate)
+
+			sliced_waveform = waveform[:, start:end]
+			sliced_name = filename.replace(".wav", f"_{pad(idx, 4)}.wav")
+
+			if not validate_waveform( sliced_waveform, sampling_rate ):
+				print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
+				continue
+
+			torchaudio.save(f"{indir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
+
+			idx = idx + 1
+			line = f"audio/{sliced_name}|{segment['text'].strip()}"
+			transcription.append(line)
+			with open(f'{indir}/train.txt', 'a', encoding="utf-8") as f:
+				f.write(f'\n{line}')
+
+	joined = "\n".join(transcription)
+	with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
+		f.write(joined)
+
+	return f"Processed dataset to: {indir}\n{joined}"
+
+def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
 	unload_tts()
 
 	global whisper_model
@@ -1079,13 +1128,6 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_a
 			if match[0] not in previous_list:
 				previous_list.append(f'{match[0].split("/")[-1]}.wav')
 
-	def validate_waveform( waveform, sample_rate ):
-		if not torch.any(waveform < 0):
-			return False
-
-		if waveform.shape[-1] < (.6 * sampling_rate):
-			return False
-		return True
 
 	for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
 		basename = os.path.basename(file)
@@ -1099,38 +1141,16 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_a
 		print(f"Transcribed file: {file}, {len(result['segments'])} found.")
 
 		waveform, sampling_rate = torchaudio.load(file)
-		num_channels, num_frames = waveform.shape
 
-		if not slice_audio:
-			if not validate_waveform( waveform, sampling_rate ):
-				print(f"Invalid waveform: {basename}, skipping...")
-				continue
+		if not validate_waveform( waveform, sampling_rate ):
+			print(f"Invalid waveform: {basename}, skipping...")
+			continue
 
-			torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
-			line = f"audio/{basename}|{result['text'].strip()}"
-			transcription.append(line)
-			with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-				f.write(f'\n{line}')
-		else:
-			idx = 0
-			for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
-				start = int(segment['start'] * sampling_rate)
-				end = int(segment['end'] * sampling_rate)
-
-				sliced_waveform = waveform[:, start:end]
-				sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
-
-				if not validate_waveform( sliced_waveform, sampling_rate ):
-					print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
-					continue
-
-				torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
-
-				idx = idx + 1
-				line = f"audio/{sliced_name}|{segment['text'].strip()}"
-				transcription.append(line)
-				with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-					f.write(f'\n{line}')
+		torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
+		line = f"audio/{basename}|{result['text'].strip()}"
+		transcription.append(line)
+		with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
+			f.write(f'\n{line}')
 
 		do_gc()
 
diff --git a/src/webui.py b/src/webui.py
index f1fc5c8..c4e0190 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -184,8 +184,11 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
 def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
 	messages = []
 
-	message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, slice_audio=slice_audio, progress=progress )
+	message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
 	messages.append(message)
+	if slice_audio:
+		message = slice_dataset( voice )
+		messages.append(message)
 	if validation_text_length > 0 or validation_audio_length > 0:
 		message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
 		messages.append(message)
@@ -418,7 +421,8 @@ def setup_gradio():
 
 					with gr.Row():
 						transcribe_button = gr.Button(value="Transcribe")
-						prepare_validation_button = gr.Button(value="Prepare Validation")
+						prepare_validation_button = gr.Button(value="(Re)Create Validation Dataset")
+						slice_dataset_button = gr.Button(value="(Re)Slice Audio")
 
 					with gr.Row():
 						EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
@@ -747,6 +751,13 @@ def setup_gradio():
 			],
 			outputs=prepare_dataset_output #console_output
 		)
+		slice_dataset_button.click(
+			slice_dataset,
+			inputs=[
+				dataset_settings[0]
+			],
+			outputs=prepare_dataset_output
+		)
 		training_refresh_dataset.click(
 			lambda: gr.update(choices=get_dataset_list()),