resample to 22.5K when creating training inputs (to avoid redundant downsampling when loaded for training, even though most of my inputs are already at 22.5K), generalized resampler function to cache and reuse them, do not unload whisper when done transcribing since it gets unloaded anyways for any other non-transcription task
This commit is contained in:
parent
7c9c0dc584
commit
050bcefd73
66
src/utils.py
66
src/utils.py
|
@ -49,6 +49,8 @@ GENERATE_SETTINGS_ARGS = None
|
|||
LEARNING_RATE_SCHEMES = {"Multistep": "MultiStepLR", "Cos. Annealing": "CosineAnnealingLR_Restart"}
|
||||
LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ]
|
||||
|
||||
RESAMPLERS = {}
|
||||
|
||||
args = None
|
||||
tts = None
|
||||
tts_loading = False
|
||||
|
@ -59,6 +61,23 @@ training_state = None
|
|||
|
||||
current_voice = None
|
||||
|
||||
def resample( waveform, input_rate, output_rate=44100 ):
|
||||
if input_rate == output_rate:
|
||||
return waveform, output_rate
|
||||
|
||||
key = f'{input_rate}:{output_rate}'
|
||||
if not key in RESAMPLERS:
|
||||
RESAMPLERS[key] = torchaudio.transforms.Resample(
|
||||
input_rate,
|
||||
output_rate,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
)
|
||||
|
||||
return RESAMPLERS[key]( waveform ), output_rate
|
||||
|
||||
def generate(**kwargs):
|
||||
parameters = {}
|
||||
parameters.update(kwargs)
|
||||
|
@ -199,17 +218,6 @@ def generate(**kwargs):
|
|||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
audio_cache = {}
|
||||
resample = None
|
||||
|
||||
if tts.output_sample_rate != args.output_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
tts.output_sample_rate,
|
||||
args.output_sample_rate,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
)
|
||||
|
||||
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
|
||||
|
||||
|
@ -343,8 +351,7 @@ def generate(**kwargs):
|
|||
for k in audio_cache:
|
||||
audio = audio_cache[k]['audio']
|
||||
|
||||
if resampler is not None:
|
||||
audio = resampler(audio)
|
||||
audio, _ = resample(audio, tts.output_sample_rate, args.output_sample_rate)
|
||||
if volume_adjust is not None:
|
||||
audio = volume_adjust(audio)
|
||||
|
||||
|
@ -1098,12 +1105,14 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
|||
|
||||
if basename in results and skip_existings:
|
||||
print(f"Skipping already parsed file: {basename}")
|
||||
continue
|
||||
else:
|
||||
results[basename] = whisper_transcribe(file, language=language)
|
||||
|
||||
results[basename] = whisper_transcribe(file, language=language)
|
||||
|
||||
# lazy copy
|
||||
waveform, sample_rate = torchaudio.load(file)
|
||||
# resample to the input rate, since it'll get resampled for training anyways
|
||||
# this should also "help" increase throughput a bit when filling the dataloaders
|
||||
waveform, sample_rate = resample(waveform, sample_rate, tts.input_sample_rate if tts is not None else 22050)
|
||||
|
||||
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate)
|
||||
|
||||
with open(infile, 'w', encoding="utf-8") as f:
|
||||
|
@ -1111,8 +1120,6 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
|||
|
||||
do_gc()
|
||||
|
||||
unload_whisper()
|
||||
|
||||
return f"Processed dataset to: {indir}"
|
||||
|
||||
def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
|
||||
|
@ -1154,9 +1161,11 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
|
|||
|
||||
if trim_silence:
|
||||
sliced = torchaudio.functional.vad( sliced, sample_rate )
|
||||
|
||||
segments +=1
|
||||
|
||||
sliced, sample_rate = resample( sample_rate, 22050 )
|
||||
torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate)
|
||||
|
||||
segments +=1
|
||||
|
||||
messages.append(f"Sliced segments: {files} => {segments}.")
|
||||
return "\n".join(messages)
|
||||
|
@ -1500,20 +1509,7 @@ def import_voices(files, saveAs=None, progress=None):
|
|||
if not voicefixer:
|
||||
load_voicefixer()
|
||||
|
||||
# resample to best bandwidth since voicefixer will do it anyways through librosa
|
||||
if sample_rate != 44100:
|
||||
print(f"Resampling imported voice sample: {path}")
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
sample_rate,
|
||||
44100,
|
||||
lowpass_filter_width=16,
|
||||
rolloff=0.85,
|
||||
resampling_method="kaiser_window",
|
||||
beta=8.555504641634386,
|
||||
)
|
||||
waveform = resampler(waveform)
|
||||
sample_rate = 44100
|
||||
|
||||
waveform, sample_rate = resample(waveform, sample_rate, 44100)
|
||||
torchaudio.save(path, waveform, sample_rate)
|
||||
|
||||
print(f"Running 'voicefixer' on voice sample: {path}")
|
||||
|
|
Loading…
Reference in New Issue
Block a user