split slicing dataset routine so it can be done after the fact

This commit is contained in:
mrq 2023-03-11 17:27:01 +00:00
parent e3fdb79b49
commit 94551fb9ac
2 changed files with 71 additions and 40 deletions

View File

@ -1051,7 +1051,56 @@ def whisper_transcribe( file, language=None ):
result['segments'].append(reparsed) result['segments'].append(reparsed)
return result return result
def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_audio=False, progress=None ): def validate_waveform( waveform, sample_rate ):
if not torch.any(waveform < 0):
return False
if waveform.shape[-1] < (.6 * sample_rate):
return False
return True
def slice_dataset( voice, start_offset=0, end_offset=0 ):
indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json'
if not os.path.exists(infile):
raise Exception(f"Missing dataset: {infile}")
with open(infile, 'r', encoding="utf-8") as f:
results = json.load(f)
transcription = []
for filename in results:
idx = 0
result = results[filename]
waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
start = int((segment['start'] + start_offset) * sampling_rate)
end = int((segment['end'] + end_offset) * sampling_rate)
sliced_waveform = waveform[:, start:end]
sliced_name = filename.replace(".wav", f"_{pad(idx, 4)}.wav")
if not validate_waveform( sliced_waveform, sampling_rate ):
print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
continue
torchaudio.save(f"{indir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
idx = idx + 1
line = f"audio/{sliced_name}|{segment['text'].strip()}"
transcription.append(line)
with open(f'{indir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'\n{line}')
joined = "\n".join(transcription)
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
f.write(joined)
return f"Processed dataset to: {indir}\n{joined}"
def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
unload_tts() unload_tts()
global whisper_model global whisper_model
@ -1079,13 +1128,6 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_a
if match[0] not in previous_list: if match[0] not in previous_list:
previous_list.append(f'{match[0].split("/")[-1]}.wav') previous_list.append(f'{match[0].split("/")[-1]}.wav')
def validate_waveform( waveform, sample_rate ):
if not torch.any(waveform < 0):
return False
if waveform.shape[-1] < (.6 * sampling_rate):
return False
return True
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
basename = os.path.basename(file) basename = os.path.basename(file)
@ -1099,38 +1141,16 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_a
print(f"Transcribed file: {file}, {len(result['segments'])} found.") print(f"Transcribed file: {file}, {len(result['segments'])} found.")
waveform, sampling_rate = torchaudio.load(file) waveform, sampling_rate = torchaudio.load(file)
num_channels, num_frames = waveform.shape
if not slice_audio: if not validate_waveform( waveform, sampling_rate ):
if not validate_waveform( waveform, sampling_rate ): print(f"Invalid waveform: {basename}, skipping...")
print(f"Invalid waveform: {basename}, skipping...") continue
continue
torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate) torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
line = f"audio/{basename}|{result['text'].strip()}" line = f"audio/{basename}|{result['text'].strip()}"
transcription.append(line) transcription.append(line)
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f: with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'\n{line}') f.write(f'\n{line}')
else:
idx = 0
for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
start = int(segment['start'] * sampling_rate)
end = int(segment['end'] * sampling_rate)
sliced_waveform = waveform[:, start:end]
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
if not validate_waveform( sliced_waveform, sampling_rate ):
print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
continue
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
idx = idx + 1
line = f"audio/{sliced_name}|{segment['text'].strip()}"
transcription.append(line)
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'\n{line}')
do_gc() do_gc()

View File

@ -184,8 +184,11 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ): def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
messages = [] messages = []
message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, slice_audio=slice_audio, progress=progress ) message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
messages.append(message) messages.append(message)
if slice_audio:
message = slice_dataset( voice )
messages.append(message)
if validation_text_length > 0 or validation_audio_length > 0: if validation_text_length > 0 or validation_audio_length > 0:
message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length ) message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
messages.append(message) messages.append(message)
@ -418,7 +421,8 @@ def setup_gradio():
with gr.Row(): with gr.Row():
transcribe_button = gr.Button(value="Transcribe") transcribe_button = gr.Button(value="Transcribe")
prepare_validation_button = gr.Button(value="Prepare Validation") prepare_validation_button = gr.Button(value="(Re)Create Validation Dataset")
slice_dataset_button = gr.Button(value="(Re)Slice Audio")
with gr.Row(): with gr.Row():
EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend) EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
@ -747,6 +751,13 @@ def setup_gradio():
], ],
outputs=prepare_dataset_output #console_output outputs=prepare_dataset_output #console_output
) )
slice_dataset_button.click(
slice_dataset,
inputs=[
dataset_settings[0]
],
outputs=prepare_dataset_output
)
training_refresh_dataset.click( training_refresh_dataset.click(
lambda: gr.update(choices=get_dataset_list()), lambda: gr.update(choices=get_dataset_list()),