From d97639e1389856b3d0e272093c942a9e53b6e19d Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 5 Mar 2023 17:54:36 +0000
Subject: [PATCH] whispercpp actually works now (language loading was weird,
 slicing needed to divide timestamps by 100); transcription now checks sliced
 segments for silence and discards them

---
 src/utils.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 145039b..0f361c8 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -39,6 +39,7 @@ from tortoise.utils.device import get_device_name, set_device_name
 MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
 WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
 WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
+EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
 
 args = None
 tts = None
@@ -997,11 +998,12 @@ def whisper_transcribe( file, language=None ):
     }
     for segment in segments:
         reparsed = {
-            'start': segment[0],
-            'end': segment[1],
+            'start': segment[0] / 100.0,
+            'end': segment[1] / 100.0,
             'text': segment[2],
         }
         result['segments'].append(reparsed)
+
     return result
 
 
@@ -1014,24 +1016,29 @@ def prepare_dataset( files, outdir, language=None, progress=None ):
     os.makedirs(outdir, exist_ok=True)
 
-    idx = 0
     results = {}
     transcription = []
 
     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
+        basename = os.path.basename(file)
         result = whisper_transcribe(file, language=language)
-        results[os.path.basename(file)] = result
+        results[basename] = result
         print(f"Transcribed file: {file}, {len(result['segments'])} found.")
 
         waveform, sampling_rate = torchaudio.load(file)
         num_channels, num_frames = waveform.shape
 
+        idx = 0
         for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
             start = int(segment['start'] * sampling_rate)
             end = int(segment['end'] * sampling_rate)
 
             sliced_waveform = waveform[:, start:end]
-            sliced_name = f"{pad(idx, 4)}.wav"
+            sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
+
+            if not torch.any(sliced_waveform < 0):
+                print(f"Error with {sliced_name}, skipping...")
+                continue
 
             torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
 
 
@@ -1056,7 +1063,6 @@ def calc_iterations( epochs, lines, batch_size ):
     iterations = int(epochs * lines / float(batch_size))
     return iterations
 
-EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
 def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
     return [int(iterations * d) for d in schedule]
 
@@ -1750,12 +1756,14 @@ def load_whisper_model(language=None, model_name=None, progress=None):
         print(f"Loading specialized model for language: {language}")
 
     notify_progress(f"Loading Whisper model: {model_name}", progress)
+
     if args.whisper_cpp:
         from whispercpp import Whisper
         if not language:
             language = 'auto'
 
-        whisper_model = Whisper(model_name, models_dir='./models/', language=language.encode('ascii'))
+        b_lang = language.encode('ascii')
+        whisper_model = Whisper(model_name, models_dir='./models/', language=b_lang)
     else:
         import whisper
         whisper_model = whisper.load_model(model_name)
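Notes on the patch: whisper.cpp reports segment timestamps in centiseconds (units of 10 ms), which is why the whisper_transcribe hunk divides segment[0] and segment[1] by 100.0; without the division, every slice computed downstream lands 100x too deep into the file. A minimal sketch of the conversion in isolation, assuming whispercpp-style (start, end, text) tuples; the helper name and output layout are illustrative, not from the patch:

    import os
    import torchaudio

    def slice_by_segments(audio_path, segments, outdir):
        """Cut audio_path into one clip per (start, end, text) tuple.

        Assumes start/end are centisecond counts, as whisper.cpp emits.
        """
        os.makedirs(outdir, exist_ok=True)
        waveform, sampling_rate = torchaudio.load(audio_path)
        for idx, (start_cs, end_cs, _text) in enumerate(segments):
            # centiseconds -> seconds -> sample offsets
            start = int((start_cs / 100.0) * sampling_rate)
            end = int((end_cs / 100.0) * sampling_rate)
            torchaudio.save(f"{outdir}/{idx:04d}.wav", waveform[:, start:end], sampling_rate)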
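The silence filter added in prepare_dataset, `if not torch.any(sliced_waveform < 0)`, only trips when no sample dips below zero, i.e. true digital silence or an all-non-negative (DC-offset) clip; any normal audio crosses zero and passes. If that heuristic proves too permissive, an RMS level check is the more conventional approach. The sketch below and its -40 dBFS cutoff are a suggestion, not part of the commit:

    import torch

    def is_silent(waveform: torch.Tensor, threshold_db: float = -40.0) -> bool:
        """True when the clip's RMS level falls below threshold_db dBFS."""
        rms = waveform.float().pow(2).mean().sqrt()
        # the epsilon keeps log10 finite on an all-zero clip
        return (20.0 * torch.log10(rms + 1e-10)).item() < threshold_db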
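EPOCH_SCHEDULE itself is only moved by this commit, hoisted from beside schedule_learning_rate to the top of the module, presumably so it is available earlier in the file. For reference, here is how the two helpers shown in context consume it; the worked numbers assume schedule_learning_rate is handed a per-epoch iteration count, which the diff does not show, so treat that as an assumption:

    EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]

    def calc_iterations( epochs, lines, batch_size ):
        return int(epochs * lines / float(batch_size))

    def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
        return [int(iterations * d) for d in schedule]

    # e.g. 500 transcript lines at batch size 100 -> 5 iterations per epoch,
    # so the learning rate would step at iterations 45, 90, 125, and 165.
    print(schedule_learning_rate(calc_iterations(1, 500, 100)))  # [45, 90, 125, 165]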
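On the language-loading fix: the whispercpp binding used here takes the language as a bytes object, and the patch splits language.encode('ascii') into a named b_lang before the constructor call. A small normalization helper in the same spirit; the helper name is hypothetical, and only the 'auto' fallback and the ascii encoding come from the patch:

    def normalize_whisper_language(language=None):
        """Return the bytes language code the whispercpp binding expects.

        The 'auto' (autodetect) fallback mirrors the patch; the helper
        itself is hypothetical, not part of the commit.
        """
        if not language:
            language = 'auto'
        return language.encode('ascii')

    # usage, mirroring the patched call:
    # whisper_model = Whisper(model_name, models_dir='./models/',
    #                         language=normalize_whisper_language(language))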