resample to 22.5K when creating training inputs (to avoid redundant downsampling when loaded for training, even though most of my inputs are already at 22.5K), generalized resampler function to cache and reuse them, do not unload whisper when done transcribing since it gets unloaded anyways for any other non-transcription task

This commit is contained in:
mrq 2023-03-13 01:20:55 +00:00
parent 7c9c0dc584
commit 050bcefd73

View File

@ -49,6 +49,8 @@ GENERATE_SETTINGS_ARGS = None
LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"} LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ] LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ]
RESAMPLERS = {}
args = None args = None
tts = None tts = None
tts_loading = False tts_loading = False
@ -59,6 +61,23 @@ training_state = None
current_voice = None current_voice = None
def resample( waveform, input_rate, output_rate=44100 ):
if input_rate == output_rate:
return waveform, output_rate
key = f'{input_rate}:{output_rate}'
if not key in RESAMPLERS:
RESAMPLERS[key] = torchaudio.transforms.Resample(
input_rate,
output_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
return RESAMPLERS[key]( waveform ), output_rate
def generate(**kwargs): def generate(**kwargs):
parameters = {} parameters = {}
parameters.update(kwargs) parameters.update(kwargs)
@ -199,17 +218,6 @@ def generate(**kwargs):
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
audio_cache = {} audio_cache = {}
resample = None
if tts.output_sample_rate != args.output_sample_rate:
resampler = torchaudio.transforms.Resample(
tts.output_sample_rate,
args.output_sample_rate,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
@ -343,8 +351,7 @@ def generate(**kwargs):
for k in audio_cache: for k in audio_cache:
audio = audio_cache[k]['audio'] audio = audio_cache[k]['audio']
if resampler is not None: audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
audio = resampler(audio)
if volume_adjust is not None: if volume_adjust is not None:
audio = volume_adjust(audio) audio = volume_adjust(audio)
@ -1098,12 +1105,14 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
if basename in results and skip_existings: if basename in results and skip_existings:
print(f"Skipping already parsed file: {basename}") print(f"Skipping already parsed file: {basename}")
continue else:
results[basename] = whisper_transcribe(file, language=language)
results[basename] = whisper_transcribe(file, language=language)
# lazy copy
waveform, sample_rate = torchaudio.load(file) waveform, sample_rate = torchaudio.load(file)
# resample to the input rate, since it'll get resampled for training anyways
# this should also "help" increase throughput a bit when filling the dataloaders
waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050)
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate) torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate)
with open(infile, 'w', encoding="utf-8") as f: with open(infile, 'w', encoding="utf-8") as f:
@ -1111,8 +1120,6 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
do_gc() do_gc()
unload_whisper()
return f"Processed dataset to: {indir}" return f"Processed dataset to: {indir}"
def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ): def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
@ -1155,9 +1162,11 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
if trim_silence: if trim_silence:
sliced = torchaudio.functional.vad( sliced, sample_rate ) sliced = torchaudio.functional.vad( sliced, sample_rate )
segments +=1 sliced, sample_rate = resample( sample_rate, 22050 )
torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate) torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate)
segments +=1
messages.append(f"Sliced segments: {files} => {segments}.") messages.append(f"Sliced segments: {files} => {segments}.")
return "\n".join(messages) return "\n".join(messages)
@ -1500,20 +1509,7 @@ def import_voices(files, saveAs=None, progress=None):
if not voicefixer: if not voicefixer:
load_voicefixer() load_voicefixer()
# resample to best bandwidth since voicefixer will do it anyways through librosa waveform, sample_rate = resample(waveform, sample_rate, 44100)
if sample_rate != 44100:
print(f"Resampling imported voice sample: {path}")
resampler = torchaudio.transforms.Resample(
sample_rate,
44100,
lowpass_filter_width=16,
rolloff=0.85,
resampling_method="kaiser_window",
beta=8.555504641634386,
)
waveform = resampler(waveform)
sample_rate = 44100
torchaudio.save(path, waveform, sample_rate) torchaudio.save(path, waveform, sample_rate)
print(f"Running 'voicefixer' on voice sample: {path}") print(f"Running 'voicefixer' on voice sample: {path}")