From 050bcefd730911f64cc354425a2f2a9896a9fbde Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 13 Mar 2023 01:20:55 +0000 Subject: [PATCH] resample to 22.5K when creating training inputs (to avoid redundant downsampling when loaded for training, even though most of my inputs are already at 22.5K), generalized resampler function to cache and reuse them, do not unload whisper when done transcribing since it gets unloaded anyways for any other non-transcription task --- src/utils.py | 66 ++++++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/src/utils.py b/src/utils.py index 466460c..508a7e6 100755 --- a/src/utils.py +++ b/src/utils.py @@ -49,6 +49,8 @@ GENERATE_SETTINGS_ARGS = None LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"} LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ] +RESAMPLERS = {} + args = None tts = None tts_loading = False @@ -59,6 +61,23 @@ training_state = None current_voice = None +def resample( waveform, input_rate, output_rate=44100 ): + if input_rate == output_rate: + return waveform, output_rate + + key = f'{input_rate}:{output_rate}' + if not key in RESAMPLERS: + RESAMPLERS[key] = torchaudio.transforms.Resample( + input_rate, + output_rate, + lowpass_filter_width=16, + rolloff=0.85, + resampling_method="kaiser_window", + beta=8.555504641634386, + ) + + return RESAMPLERS[key]( waveform ), output_rate + def generate(**kwargs): parameters = {} parameters.update(kwargs) @@ -199,17 +218,6 @@ def generate(**kwargs): os.makedirs(outdir, exist_ok=True) audio_cache = {} - resample = None - - if tts.output_sample_rate != args.output_sample_rate: - resampler = torchaudio.transforms.Resample( - tts.output_sample_rate, - args.output_sample_rate, - lowpass_filter_width=16, - rolloff=0.85, - resampling_method="kaiser_window", - beta=8.555504641634386, - ) volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None @@ -343,8 +351,7 @@ def generate(**kwargs): for k in audio_cache: audio = audio_cache[k]['audio'] - if resampler is not None: - audio = resampler(audio) + audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate) if volume_adjust is not None: audio = volume_adjust(audio) @@ -1098,12 +1105,14 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non if basename in results and skip_existings: print(f"Skipping already parsed file: {basename}") - continue - - results[basename] = whisper_transcribe(file, language=language) + else: + results[basename] = whisper_transcribe(file, language=language) - # lazy copy waveform, sample_rate = torchaudio.load(file) + # resample to the input rate, since it'll get resampled for training anyways + # this should also "help" increase throughput a bit when filling the dataloaders + waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050) + torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate) with open(infile, 'w', encoding="utf-8") as f: @@ -1111,8 +1120,6 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non do_gc() - unload_whisper() - return f"Processed dataset to: {indir}" def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ): @@ -1154,9 +1161,11 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ): if trim_silence: sliced = torchaudio.functional.vad( sliced, sample_rate ) - - segments +=1 + + sliced, sample_rate = resample( sample_rate, 22050 ) torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate) + + segments +=1 messages.append(f"Sliced segments: {files} => {segments}.") return "\n".join(messages) @@ -1500,20 +1509,7 @@ def import_voices(files, saveAs=None, progress=None): if not voicefixer: load_voicefixer() - # resample to best bandwidth since voicefixer will do it anyways through librosa - if sample_rate != 44100: - print(f"Resampling imported voice sample: {path}") - resampler = torchaudio.transforms.Resample( - sample_rate, - 44100, - lowpass_filter_width=16, - rolloff=0.85, - resampling_method="kaiser_window", - beta=8.555504641634386, - ) - waveform = resampler(waveform) - sample_rate = 44100 - + waveform, sample_rate = resample(waveform, sample_rate, 44100) torchaudio.save(path, waveform, sample_rate) print(f"Running 'voicefixer' on voice sample: {path}")