diff --git a/src/utils.py b/src/utils.py
index d549145..8fe8fc6 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -33,7 +33,7 @@ from datetime import datetime
 from datetime import timedelta
 
 from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate
-from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
+from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir, get_voices
 from tortoise.utils.text import split_and_recombine_text
 from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram
 
@@ -1059,155 +1059,119 @@ def validate_waveform( waveform, sample_rate ):
         return False
 
     return True
 
-def slice_dataset( voice, start_offset=0, end_offset=0 ):
-    indir = f'./training/{voice}/'
-    infile = f'{indir}/whisper.json'
-
-    if not os.path.exists(infile):
-        raise Exception(f"Missing dataset: {infile}")
-
-    with open(infile, 'r', encoding="utf-8") as f:
-        results = json.load(f)
-
-    transcription = []
-    for filename in results:
-        idx = 0
-        result = results[filename]
-        waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
-
-        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
-            start = int((segment['start'] + start_offset) * sampling_rate)
-            end = int((segment['end'] + end_offset) * sampling_rate)
-
-            sliced_waveform = waveform[:, start:end]
-            sliced_name = filename.replace(".wav", f"_{pad(idx, 4)}.wav")
-
-            if not validate_waveform( sliced_waveform, sampling_rate ):
-                print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {sliced_name}, skipping...")
-                continue
-
-            torchaudio.save(f"{indir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
-
-            idx = idx + 1
-            line = f"audio/{sliced_name}|{segment['text'].strip()}"
-            transcription.append(line)
-            with open(f'{indir}/train.txt', 'a', encoding="utf-8") as f:
-                f.write(f'\n{line}')
-
-    joined = "\n".join(transcription)
-    with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
-        f.write(joined)
-
-    return f"Processed dataset to: {indir}\n{joined}"
-
-def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ):
+def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
     unload_tts()
 
     global whisper_model
     if whisper_model is None:
         load_whisper_model(language=language)
 
-    os.makedirs(f'{outdir}/audio/', exist_ok=True)
     results = {}
-    transcription = []
 
-    files = sorted(files)
-
-    previous_list = []
-    if skip_existings and os.path.exists(f'{outdir}/train.txt'):
-        parsed_list = []
-        with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f:
-            parsed_list = f.readlines()
-
-        for line in parsed_list:
-            match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
-
-            if match is None or len(match) == 0:
-                continue
-
-            if match[0] not in previous_list:
-                previous_list.append(f'{match[0].split("/")[-1]}.wav')
+    files = sorted( get_voices(load_latents=False)[voice] )
+    indir = f'./training/{voice}/'
+    infile = f'{indir}/whisper.json'
+
+    os.makedirs(f'{indir}/audio/', exist_ok=True)
+
+    if os.path.exists(infile):
+        results = json.load(open(infile, 'r', encoding="utf-8"))
 
     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
         basename = os.path.basename(file)
 
-        if basename in previous_list:
+        if basename in results and skip_existings:
             print(f"Skipping already parsed file: {basename}")
             continue
 
-        result = whisper_transcribe(file, language=language)
-        results[basename] = result
-        print(f"Transcribed file: {file}, {len(result['segments'])} found.")
+        results[basename] = whisper_transcribe(file, language=language)
 
+        # lazy copy
         waveform, sampling_rate = torchaudio.load(file)
+        torchaudio.save(f"{indir}/audio/{basename}", waveform, sampling_rate)
 
-        if not validate_waveform( waveform, sampling_rate ):
-            print(f"Invalid waveform: {basename}, skipping...")
-            continue
-
-        torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate)
-        line = f"audio/{basename}|{result['text'].strip()}"
-        transcription.append(line)
-        with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
-            f.write(f'\n{line}')
+        with open(infile, 'w', encoding="utf-8") as f:
+            f.write(json.dumps(results, indent='\t'))
 
         do_gc()
-
-    with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
-        f.write(json.dumps(results, indent='\t'))
 
     unload_whisper()
 
-    joined = "\n".join(transcription)
-    if not skip_existings:
-        with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f:
-            f.write(joined)
-
-    return f"Processed dataset to: {outdir}\n{joined}"
+    return f"Processed dataset to: {indir}"
 
-def prepare_validation_dataset( voice, text_length, audio_length ):
+def slice_dataset( voice, start_offset=0, end_offset=0 ):
     indir = f'./training/{voice}/'
-    infile = f'{indir}/dataset.txt'
+    infile = f'{indir}/whisper.json'
+
     if not os.path.exists(infile):
-        infile = f'{indir}/train.txt'
-        with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
-            with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
-                dst.write(src.read())
+        raise Exception(f"Missing dataset: {infile}")
+
+    results = json.load(open(infile, 'r', encoding="utf-8"))
+
+    files = 0
+    segments = 0
+    for filename in results:
+        files += 1
+
+        result = results[filename]
+        waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
+
+        for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
+            segments += 1
+            start = int((segment['start'] + start_offset) * sampling_rate)
+            end = int((segment['end'] + end_offset) * sampling_rate)
+
+            sliced = waveform[:, start:end]
+            file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
+
+            if not validate_waveform( sliced, sampling_rate ):
+                print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {file}, skipping...")
+                continue
+
+            torchaudio.save(f"{indir}/audio/{file}", sliced, sampling_rate)
+
+    return f"Sliced segments: {files} => {segments}."
+
+def prepare_dataset( voice, use_segments, text_length, audio_length ):
+    indir = f'./training/{voice}/'
+    infile = f'{indir}/whisper.json'
 
     if not os.path.exists(infile):
         raise Exception(f"Missing dataset: {infile}")
 
-    with open(infile, 'r', encoding="utf-8") as f:
-        lines = f.readlines()
+    results = json.load(open(infile, 'r', encoding="utf-8"))
+
+    lines = {
+        'training': [],
+        'validation': [],
+    }
 
-    validation = []
-    training = []
+    for filename in results:
+        result = results[filename]
+        segments = result['segments'] if use_segments else [{'text': result['text']}]
+        for segment in segments:
+            text = segment['text'].strip()
+            file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
 
-    for line in lines:
-        split = line.split("|")
-        filename = split[0]
-        text = split[1]
-        culled = len(text) < text_length
+            culled = len(text) < text_length
+            if not culled and audio_length > 0:
+                metadata = torchaudio.info(f'{indir}/audio/{file}')
+                duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
+                culled = duration < audio_length
 
-        if not culled and audio_length > 0:
-            metadata = torchaudio.info(f'{indir}/{filename}')
-            duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
-            culled = duration < audio_length
+            lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
 
-        if culled:
-            validation.append(line.strip())
-        else:
-            training.append(line.strip())
+    training_joined = "\n".join(lines['training'])
+    validation_joined = "\n".join(lines['validation'])
 
     with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
-        f.write("\n".join(training))
+        f.write(training_joined)
 
     with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
-        f.write("\n".join(validation))
+        f.write(validation_joined)
 
-    msg = f"Culled {len(validation)}/{len(lines)} lines."
-    print(msg)
+    msg = f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}"
     return msg
 
 def calc_iterations( epochs, lines, batch_size ):
diff --git a/src/webui.py b/src/webui.py
index c4e0190..82bec2f 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -182,16 +182,19 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
         gr.update(visible=j is not None),
     )
 
-def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ):
+def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=False) ):
     messages = []
-    message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress )
+
+    message = transcribe_dataset( voice=voice, language=language, skip_existings=skip_existings, progress=progress )
     messages.append(message)
+
     if slice_audio:
         message = slice_dataset( voice )
         messages.append(message)
-    if validation_text_length > 0 or validation_audio_length > 0:
-        message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length )
-        messages.append(message)
+
+    message = prepare_dataset( voice, use_segments=slice_audio, text_length=validation_text_length, audio_length=validation_audio_length )
+    messages.append(message)
+
     return "\n".join(messages)
 
 def update_args_proxy( *args ):
@@ -421,8 +424,8 @@
 
                         with gr.Row():
                             transcribe_button = gr.Button(value="Transcribe")
-                            prepare_validation_button = gr.Button(value="(Re)Create Validation Dataset")
                             slice_dataset_button = gr.Button(value="(Re)Slice Audio")
+                            prepare_dataset_button = gr.Button(value="(Re)Create Dataset")
                         with gr.Row():
                             EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend)
 
@@ -654,7 +657,7 @@
             inputs=None,
             outputs=[
                 GENERATE_SETTINGS['voice'],
-                dataset_settings[0],
+                DATASET_SETTINGS['voice'],
                 history_voices
             ]
         )
@@ -742,10 +745,11 @@
            inputs=dataset_settings,
            outputs=prepare_dataset_output #console_output
        )
-        prepare_validation_button.click(
-            prepare_validation_dataset,
+        prepare_dataset_button.click(
+            prepare_dataset,
             inputs=[
-                dataset_settings[0],
+                DATASET_SETTINGS['voice'],
+                DATASET_SETTINGS['slice'],
                 DATASET_SETTINGS['validation_text_length'],
                 DATASET_SETTINGS['validation_audio_length'],
             ],
@@ -754,7 +758,7 @@
         slice_dataset_button.click(
             slice_dataset,
             inputs=[
-                dataset_settings[0]
+                DATASET_SETTINGS['voice']
             ],
             outputs=prepare_dataset_output
         )
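
For reference, the refactored flow this patch wires into `prepare_dataset_proxy` is a three-step pipeline: transcribe every clip with whisper, optionally slice the clips per segment, then build the train/validation lists. Below is a minimal sketch of driving it outside the web UI; it assumes the functions are importable from `src/utils.py`, that whisper and its model weights are already set up, and that a hypothetical voice folder `./voices/example/` exists (the threshold values are illustrative, not project defaults).

```python
# Sketch only: mirrors prepare_dataset_proxy from src/webui.py above.
# "example" is a placeholder voice name; transcribe_dataset reads the clips in
# ./voices/example/ and writes ./training/example/whisper.json plus a copy of
# each clip under ./training/example/audio/.
from utils import transcribe_dataset, slice_dataset, prepare_dataset

voice = "example"
use_segments = True  # same role as the slice_audio option in prepare_dataset_proxy

# 1. Whisper-transcribe each clip; results accumulate in whisper.json.
print(transcribe_dataset(voice=voice, language="en", skip_existings=True))

# 2. Optionally cut each clip into per-segment wavs under ./training/example/audio/.
if use_segments:
    print(slice_dataset(voice))

# 3. Write train.txt / validation.txt; lines whose text is shorter than
#    text_length or whose audio is shorter than audio_length go to validation.txt.
print(prepare_dataset(voice, use_segments=use_segments, text_length=12, audio_length=1))
```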