diff --git a/src/utils.py b/src/utils.py index 466460c..508a7e6 100755 --- a/src/utils.py +++ b/src/utils.py @@ -49,6 +49,8 @@ GENERATE_SETTINGS_ARGS = None LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"} LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ] +RESAMPLERS = {} + args = None tts = None tts_loading = False @@ -59,6 +61,23 @@ training_state = None current_voice = None +def resample( waveform, input_rate, output_rate=44100 ): + if input_rate == output_rate: + return waveform, output_rate + + key = f'{input_rate}:{output_rate}' + if not key in RESAMPLERS: + RESAMPLERS[key] = torchaudio.transforms.Resample( + input_rate, + output_rate, + lowpass_filter_width=16, + rolloff=0.85, + resampling_method="kaiser_window", + beta=8.555504641634386, + ) + + return RESAMPLERS[key]( waveform ), output_rate + def generate(**kwargs): parameters = {} parameters.update(kwargs) @@ -199,17 +218,6 @@ def generate(**kwargs): os.makedirs(outdir, exist_ok=True) audio_cache = {} - resample = None - - if tts.output_sample_rate != args.output_sample_rate: - resampler = torchaudio.transforms.Resample( - tts.output_sample_rate, - args.output_sample_rate, - lowpass_filter_width=16, - rolloff=0.85, - resampling_method="kaiser_window", - beta=8.555504641634386, - ) volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None @@ -343,8 +351,7 @@ def generate(**kwargs): for k in audio_cache: audio = audio_cache[k]['audio'] - if resampler is not None: - audio = resampler(audio) + audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate) if volume_adjust is not None: audio = volume_adjust(audio) @@ -1098,12 +1105,14 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non if basename in results and skip_existings: print(f"Skipping already parsed file: {basename}") - continue - - results[basename] = whisper_transcribe(file, language=language) + else: + results[basename] = whisper_transcribe(file, language=language) - # lazy copy waveform, sample_rate = torchaudio.load(file) + # resample to the input rate, since it'll get resampled for training anyways + # this should also "help" increase throughput a bit when filling the dataloaders + waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050) + torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate) with open(infile, 'w', encoding="utf-8") as f: @@ -1111,8 +1120,6 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non do_gc() - unload_whisper() - return f"Processed dataset to: {indir}" def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ): @@ -1154,9 +1161,11 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ): if trim_silence: sliced = torchaudio.functional.vad( sliced, sample_rate ) - - segments +=1 + + sliced, sample_rate = resample( sample_rate, 22050 ) torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate) + + segments +=1 messages.append(f"Sliced segments: {files} => {segments}.") return "\n".join(messages) @@ -1500,20 +1509,7 @@ def import_voices(files, saveAs=None, progress=None): if not voicefixer: load_voicefixer() - # resample to best bandwidth since voicefixer will do it anyways through librosa - if sample_rate != 44100: - print(f"Resampling imported voice sample: {path}") - resampler = torchaudio.transforms.Resample( - sample_rate, - 44100, - lowpass_filter_width=16, - rolloff=0.85, - resampling_method="kaiser_window", - beta=8.555504641634386, - ) - waveform = resampler(waveform) - sample_rate = 44100 - + waveform, sample_rate = resample(waveform, sample_rate, 44100) torchaudio.save(path, waveform, sample_rate) print(f"Running 'voicefixer' on voice sample: {path}")