Resample to 22.05 kHz (22050 Hz) when creating training inputs (to avoid redundant downsampling when they are loaded for training, even though most of my inputs are already at 22.05 kHz); generalized the resampler function to cache and reuse resamplers; do not unload whisper when done transcribing, since it gets unloaded anyway for any other non-transcription task.

This commit is contained in:
mrq 2023-03-13 01:20:55 +00:00
parent 7c9c0dc584
commit 050bcefd73

View File

@ -49,6 +49,8 @@ GENERATE_SETTINGS_ARGS = None
# Label -> scheduler-name pairs (presumably UI label to trainer scheduler
# identifier -- confirm against the code that consumes this mapping).
LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
# NOTE(review): looks like milestone values for the learning-rate schedule;
# units are not visible here -- confirm.
LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ]
# Cache of torchaudio Resample transforms, filled lazily by resample() and
# keyed by the string '<input_rate>:<output_rate>'.
RESAMPLERS = {}
# Runtime settings object; populated elsewhere. Code in this file reads
# args.output_sample_rate and args.output_volume from it.
args = None
# Loaded TTS model, or None while no model is loaded. Code in this file reads
# tts.input_sample_rate and tts.output_sample_rate from it.
tts = None
# Presumably set while the TTS model is being loaded -- confirm at its writers.
tts_loading = False
@ -59,6 +61,23 @@ training_state = None
current_voice = None
def resample(waveform, input_rate, output_rate=44100):
    """Resample ``waveform`` from ``input_rate`` to ``output_rate``.

    A ``torchaudio.transforms.Resample`` transform is built once per
    (input_rate, output_rate) pair and stored in the module-level
    ``RESAMPLERS`` dict, so later calls with the same rates reuse it instead
    of constructing a new transform each time.

    Args:
        waveform: Audio data passed to the Resample transform.
        input_rate: Sample rate of ``waveform``.
        output_rate: Target sample rate; defaults to 44100.

    Returns:
        A ``(waveform, output_rate)`` tuple. When both rates are already
        equal, ``waveform`` is returned as-is and the cache is not touched.
    """
    # Nothing to do: skip building or looking up a transform entirely.
    if input_rate == output_rate:
        return waveform, output_rate

    # The string key format is kept as-is so any entries already stored in
    # RESAMPLERS by other code remain valid.
    key = f'{input_rate}:{output_rate}'
    if key not in RESAMPLERS:
        RESAMPLERS[key] = torchaudio.transforms.Resample(
            input_rate,
            output_rate,
            lowpass_filter_width=16,
            rolloff=0.85,
            resampling_method="kaiser_window",
            beta=8.555504641634386,
        )

    return RESAMPLERS[key](waveform), output_rate
def generate(**kwargs):
parameters = {}
parameters.update(kwargs)
@ -199,17 +218,6 @@ def generate(**kwargs):
os.makedirs(outdir, exist_ok=True)
audio_cache = {}
resample = None
if tts.output_sample_rate != args.output_sample_rate:
resampler = torchaudio.transforms.Resample(
tts.output_sample_rate,
args.output_sample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
@ -343,8 +351,7 @@ def generate(**kwargs):
for k in audio_cache:
audio = audio_cache[k]['audio']
if resampler is not None:
audio = resampler(audio)
audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
if volume_adjust is not None:
audio = volume_adjust(audio)
@ -1098,12 +1105,14 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
if basename in results and skip_existings:
print(f"Skipping already parsed file: {basename}")
continue
else:
results[basename] = whisper_transcribe(file, language=language)
results[basename] = whisper_transcribe(file, language=language)
# lazy copy
waveform, sample_rate = torchaudio.load(file)
# resample to the input rate, since it'll get resampled for training anyways
# this should also "help" increase throughput a bit when filling the dataloaders
waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050)
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate)
with open(infile, 'w', encoding="utf-8") as f:
@ -1111,8 +1120,6 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
do_gc()
unload_whisper()
return f"Processed dataset to: {indir}"
def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
@ -1155,9 +1162,11 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
if trim_silence:
sliced = torchaudio.functional.vad( sliced, sample_rate )
segments +=1
sliced, sample_rate = resample( sample_rate, 22050 )
torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate)
segments +=1
messages.append(f"Sliced segments: {files} => {segments}.")
return "\n".join(messages)
@ -1500,20 +1509,7 @@ def import_voices(files, saveAs=None, progress=None):
if not voicefixer:
load_voicefixer()
# resample to best bandwidth since voicefixer will do it anyways through librosa
if sample_rate != 44100:
print(f"Resampling imported voice sample: {path}")
resampler = torchaudio.transforms.Resample(
sample_rate,
44100,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
waveform = resampler(waveform)
sample_rate = 44100
waveform, sample_rate = resample(waveform, sample_rate, 44100)
torchaudio.save(path, waveform, sample_rate)
print(f"Running 'voicefixer' on voice sample: {path}")