From 050bcefd730911f64cc354425a2f2a9896a9fbde Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Mon, 13 Mar 2023 01:20:55 +0000
Subject: [PATCH] resample to 22.05K when creating training inputs (to avoid
 redundant downsampling when loaded for training, even though most of my
 inputs are already at 22.05K), generalized resampler function to cache and
 reuse them, do not unload whisper when done transcribing since it gets
 unloaded anyways for any other non-transcription task

---
 src/utils.py | 66 ++++++++++++++++++++++++----------------------------
 1 file changed, 31 insertions(+), 35 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 466460c..508a7e6 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -49,6 +49,8 @@ GENERATE_SETTINGS_ARGS = None
 LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
 LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ]
 
+RESAMPLERS = {}
+
 args = None
 tts = None
 tts_loading = False
@@ -59,6 +61,23 @@ training_state = None
 
 current_voice = None
 
+def resample( waveform, input_rate, output_rate=44100 ): # returns a (waveform, output_rate) tuple
+	if input_rate == output_rate: # rates already match: hand the waveform back untouched
+		return waveform, output_rate
+	# one Resample transform is built per "input:output" rate pair and kept in RESAMPLERS for reuse
+	key = f'{input_rate}:{output_rate}'
+	if not key in RESAMPLERS:
+		RESAMPLERS[key] = torchaudio.transforms.Resample(
+			input_rate,
+			output_rate,
+			lowpass_filter_width=16,
+			rolloff=0.85,
+			resampling_method="kaiser_window",
+			beta=8.555504641634386,
+		)
+
+	return RESAMPLERS[key]( waveform ), output_rate
+
 def generate(**kwargs):
 	parameters = {}
 	parameters.update(kwargs)
@@ -199,17 +218,6 @@ def generate(**kwargs):
 	os.makedirs(outdir, exist_ok=True)
 
 	audio_cache = {}
-	resample = None
-
-	if tts.output_sample_rate != args.output_sample_rate:
-		resampler = torchaudio.transforms.Resample(
-			tts.output_sample_rate,
-			args.output_sample_rate,
-			lowpass_filter_width=16,
-			rolloff=0.85,
-			resampling_method="kaiser_window",
-			beta=8.555504641634386,
-		)
 
 	volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
 
@@ -343,8 +351,7 @@ def generate(**kwargs):
 	for k in audio_cache:
 		audio = audio_cache[k]['audio']
 
-		if resampler is not None:
-			audio = resampler(audio)
+		audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
 		if volume_adjust is not None:
 			audio = volume_adjust(audio)
 
@@ -1098,12 +1105,14 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 
 		if basename in results and skip_existings:
 			print(f"Skipping already parsed file: {basename}")
-			continue
+		else:
+			results[basename] = whisper_transcribe(file, language=language)
 
-		results[basename] = whisper_transcribe(file, language=language)
-
-		# lazy copy
 		waveform, sample_rate = torchaudio.load(file)
+		# resample to the input rate, since it'll get resampled for training anyways
+		# this should also "help" increase throughput a bit when filling the dataloaders
+		waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050)
+
 		torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate)
 
 		with open(infile, 'w', encoding="utf-8") as f:
@@ -1111,8 +1120,6 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 
 		do_gc()
 
-	unload_whisper()
-
 	return f"Processed dataset to: {indir}"
 
 def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
@@ -1154,9 +1161,11 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 
 			if trim_silence:
 				sliced = torchaudio.functional.vad( sliced, sample_rate )
-
-			segments +=1
+			# resample() takes (waveform, input_rate, output_rate); bring the slice to 22050 Hz before saving
+			sliced, sample_rate = resample( sliced, sample_rate, 22050 )
 			torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate)
+
+			segments +=1
 
 	messages.append(f"Sliced segments: {files} => {segments}.")
 	return "\n".join(messages)
@@ -1500,20 +1509,7 @@ def import_voices(files, saveAs=None, progress=None):
 				if not voicefixer:
 					load_voicefixer()
 
-				# resample to best bandwidth since voicefixer will do it anyways through librosa
-				if sample_rate != 44100:
-					print(f"Resampling imported voice sample: {path}")
-					resampler = torchaudio.transforms.Resample(
-						sample_rate,
-						44100,
-						lowpass_filter_width=16,
-						rolloff=0.85,
-						resampling_method="kaiser_window",
-						beta=8.555504641634386,
-					)
-					waveform = resampler(waveform)
-					sample_rate = 44100
-
+				waveform, sample_rate = resample(waveform, sample_rate, 44100)
 				torchaudio.save(path, waveform, sample_rate)
 
 				print(f"Running 'voicefixer' on voice sample: {path}")