From aa5bdafb06564e507974e87f6f1365cd3d38c9a6 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 22 Mar 2023 20:26:28 +0000
Subject: [PATCH] ugh

---
 src/utils.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/utils.py b/src/utils.py
index 5ccb93d..389d2af 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -1131,6 +1131,24 @@ def convert_to_halfp():
 	torch.save(model, outfile)
 	print(f'Converted model to half precision: {outfile}')
 
+
+# collapses short segments into the previous segment
+def whisper_sanitize( results ):
+	sanitized = dict(results)
+	sanitized['segments'] = []
+
+	for segment in results['segments']:
+		length = segment['end'] - segment['start']
+		if length >= MIN_TRAINING_DURATION or len(sanitized['segments']) == 0:
+			sanitized['segments'].append(segment)
+			continue
+
+		last_segment = sanitized['segments'][-1]
+		last_segment['text'] += segment['text']
+		last_segment['end'] = segment['end']
+
+	return sanitized
+
 def whisper_transcribe( file, language=None ):
 	# shouldn't happen, but it's for safety
 	global whisper_model
@@ -1150,7 +1168,7 @@ def whisper_transcribe( file, language=None ):
 	segments = whisper_model.extract_text_and_timestamps( res )
 
 	result = {
-		'text': []
+		'text': [],
 		'segments': []
 	}
 	for segment in segments:
@@ -1248,6 +1266,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 		result = whisper_transcribe(file, language=language)
 		results[basename] = result
 
+		# results[basename] = whisper_sanitize(results[basename])
+
 		waveform, sample_rate = torchaudio.load(file)
 		# resample to the input rate, since it'll get resampled for training anyways
 		# this should also "help" increase throughput a bit when filling the dataloaders