From 2424c455cb9614003c072f6cdc25fa80ba2694ba Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 11 Mar 2023 16:32:35 +0000 Subject: [PATCH] added option to not slice audio when transcribing, added option to prepare validation dataset on audio duration, added a warning if youre using whisperx and you're slicing audio --- src/utils.py | 74 ++++++++++++++++++++++++++++++++++++---------------- src/webui.py | 32 +++++++++++++---------- 2 files changed, 69 insertions(+), 37 deletions(-) mode change 100644 => 100755 src/webui.py diff --git a/src/utils.py b/src/utils.py index fb39866..01478ba 100755 --- a/src/utils.py +++ b/src/utils.py @@ -667,7 +667,7 @@ class TrainingState(): self.steps = int(self.info['steps']) if 'iteration_rate' in self.info: - it_rate = self.info['iteration_rate'] + it_rate = self.info['iteration_rate'] / self.batch_size # why self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it' self.it_rates += it_rate @@ -676,6 +676,7 @@ class TrainingState(): eta = str(timedelta(seconds=int(self.eta))) self.eta_hhmmss = eta except Exception as e: + self.eta_hhmmss = "?" pass self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] @@ -1064,13 +1065,16 @@ def whisper_transcribe( file, language=None ): return result -def prepare_dataset( files, outdir, language=None, skip_existings=False, progress=None ): +def prepare_dataset( files, outdir, language=None, skip_existings=False, slice_audio=False, progress=None ): unload_tts() global whisper_model if whisper_model is None: load_whisper_model(language=language) + if args.whisper_backend == "m-bain/whisperx" and slice_audio: + print("! CAUTION ! Slicing audio with whisperx is terrible. Please consider using a different whisper backend if you want to slice audio.") + os.makedirs(f'{outdir}/audio/', exist_ok=True) results = {} @@ -1092,6 +1096,14 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres if match[0] not in previous_list: previous_list.append(f'{match[0].split("/")[-1]}.wav') + def validate_waveform( waveform, sample_rate, name ): + if not torch.any(waveform < 0): + return False + + if waveform.shape[-1] < (.6 * sampling_rate): + return False + return True + for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): basename = os.path.basename(file) @@ -1106,29 +1118,36 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres waveform, sampling_rate = torchaudio.load(file) num_channels, num_frames = waveform.shape - idx = 0 - for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress): - start = int(segment['start'] * sampling_rate) - end = int(segment['end'] * sampling_rate) - - sliced_waveform = waveform[:, start:end] - sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav") - - if not torch.any(sliced_waveform < 0): - print(f"Sound file is silent: {sliced_name}, skipping...") + if not slice_audio: + if not validate_waveform( waveform, sampling_rate, name ): + print(f"Segment invalid: {name}, skipping...") continue - if sliced_waveform.shape[-1] < (.6 * sampling_rate): - print(f"Sound file is too short: {sliced_name}, skipping...") - continue - - torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate) - - idx = idx + 1 - line = f"audio/{sliced_name}|{segment['text'].strip()}" + torchaudio.save(f"{outdir}/audio/{basename}", waveform, sampling_rate) + line = f"audio/{basename}|{result['text'].strip()}" transcription.append(line) with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f: f.write(f'\n{line}') + else: + idx = 0 + for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress): + start = int(segment['start'] * sampling_rate) + end = int(segment['end'] * sampling_rate) + + sliced_waveform = waveform[:, start:end] + sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav") + + if not validate_waveform( sliced_waveform, sampling_rate, sliced_name ): + print(f"Trimmed segment invalid: {sliced_name}, skipping...") + continue + + torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate) + + idx = idx + 1 + line = f"audio/{sliced_name}|{segment['text'].strip()}" + transcription.append(line) + with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f: + f.write(f'\n{line}') do_gc() @@ -1144,7 +1163,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres return f"Processed dataset to: {outdir}\n{joined}" -def prepare_validation_dataset( voice, text_length ): +def prepare_validation_dataset( voice, text_length, audio_length ): indir = f'./training/{voice}/' infile = f'{indir}/dataset.txt' if not os.path.exists(infile): @@ -1166,8 +1185,14 @@ def prepare_validation_dataset( voice, text_length ): split = line.split("|") filename = split[0] text = split[1] + culled = len(text) < text_length - if len(text) < text_length: + if not culled and audio_length > 0: + metadata = torchaudio.info(f'{indir}/{filename}') + duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate + culled = duration < audio_length + + if culled: validation.append(line.strip()) else: training.append(line.strip()) @@ -1178,7 +1203,7 @@ def prepare_validation_dataset( voice, text_length ): with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: f.write("\n".join(validation)) - msg = f"Culled {len(validation)} lines" + msg = f"Culled {len(validation)}/{len(lines)} lines." print(msg) return msg @@ -1896,6 +1921,9 @@ def load_tts( restart=False, autoregressive_model=None ): print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})") + if get_device_name() == "cpu": + print("!!!! WARNING !!!! No GPU available in PyTorch. You may need to reinstall PyTorch.") + tts_loading = True try: tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model) diff --git a/src/webui.py b/src/webui.py old mode 100644 new mode 100755 index 1816f14..f1fc5c8 --- a/src/webui.py +++ b/src/webui.py @@ -152,9 +152,7 @@ def import_generate_settings_proxy( file=None ): res = [] for k in GENERATE_SETTINGS_ARGS: res.append(settings[k] if k in settings else None) - print(GENERATE_SETTINGS_ARGS) - print(settings) - print(res) + return tuple(res) def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): @@ -184,12 +182,12 @@ def read_generate_settings_proxy(file, saveAs='.temp'): gr.update(visible=j is not None), ) -def prepare_dataset_proxy( voice, language, validation_size, skip_existings, progress=gr.Progress(track_tqdm=True) ): +def prepare_dataset_proxy( voice, language, validation_text_length, validation_audio_length, skip_existings, slice_audio, progress=gr.Progress(track_tqdm=True) ): messages = [] - message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, progress=progress ) + message = prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, skip_existings=skip_existings, slice_audio=slice_audio, progress=progress ) messages.append(message) - if validation_size > 0: - message = prepare_validation_dataset( voice, text_length=validation_size ) + if validation_text_length > 0 or validation_audio_length > 0: + message = prepare_validation_dataset( voice, text_length=validation_text_length, audio_length=validation_audio_length ) messages.append(message) return "\n".join(messages) @@ -246,8 +244,7 @@ def import_training_settings_proxy( voice ): output[k] = settings[k] output = list(output.values()) - print(list(TRAINING_SETTINGS.keys())) - print(output) + messages.append(f"Imported training settings: {injson}") return output[:-1] + ["\n".join(messages)] @@ -413,13 +410,20 @@ def setup_gradio(): DATASET_SETTINGS['voice'] = gr.Dropdown( choices=voice_list, label="Dataset Source", type="value", value=voice_list[0] if len(voice_list) > 0 else "" ) with gr.Row(): DATASET_SETTINGS['language'] = gr.Textbox(label="Language", value="en") - DATASET_SETTINGS['validation_size'] = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0) - DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False) + DATASET_SETTINGS['validation_text_length'] = gr.Number(label="Validation Text Length Threshold", value=12, precision=0) + DATASET_SETTINGS['validation_audio_length'] = gr.Number(label="Validation Audio Length Threshold", value=1 ) + with gr.Row(): + DATASET_SETTINGS['skip'] = gr.Checkbox(label="Skip Already Transcribed", value=False) + DATASET_SETTINGS['slice'] = gr.Checkbox(label="Slice Segments", value=False) with gr.Row(): transcribe_button = gr.Button(value="Transcribe") prepare_validation_button = gr.Button(value="Prepare Validation") + with gr.Row(): + EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend) + EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model) + dataset_settings = list(DATASET_SETTINGS.values()) with gr.Column(): prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) @@ -533,8 +537,7 @@ def setup_gradio(): EXEC_SETTINGS['autoregressive_model'] = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) EXEC_SETTINGS['vocoder_model'] = gr.Dropdown(VOCODERS, label="Vocoder", value=args.vocoder_model if args.vocoder_model else VOCODERS[-1]) - EXEC_SETTINGS['whisper_backend'] = gr.Dropdown(WHISPER_BACKENDS, label="Whisper Backends", value=args.whisper_backend) - EXEC_SETTINGS['whisper_model'] = gr.Dropdown(WHISPER_MODELS, label="Whisper Model", value=args.whisper_model) + EXEC_SETTINGS['training_default_halfp'] = TRAINING_SETTINGS['half_p'] EXEC_SETTINGS['training_default_bnb'] = TRAINING_SETTINGS['bitsandbytes'] @@ -739,7 +742,8 @@ def setup_gradio(): prepare_validation_dataset, inputs=[ dataset_settings[0], - DATASET_SETTINGS['validation_size'], + DATASET_SETTINGS['validation_text_length'], + DATASET_SETTINGS['validation_audio_length'], ], outputs=prepare_dataset_output #console_output )