mrq 2023-03-22 20:26:28 +00:00
parent 13605f980c
commit aa5bdafb06


@@ -1131,6 +1131,24 @@ def convert_to_halfp():
     torch.save(model, outfile)
     print(f'Converted model to half precision: {outfile}')
 
+# collapses short segments into the previous segment
+def whisper_sanitize( results ):
+    # work on a copy so the original 'segments' list can still be iterated below
+    sanitized = dict(results)
+    sanitized['segments'] = []
+
+    for segment in results['segments']:
+        length = segment['end'] - segment['start']
+        # keep segments that are long enough, and always keep the first one
+        if length >= MIN_TRAINING_DURATION or len(sanitized['segments']) == 0:
+            sanitized['segments'].append(segment)
+            continue
+        # otherwise fold the short segment into the previous one
+        last_segment = sanitized['segments'][-1]
+        last_segment['text'] += segment['text']
+        last_segment['end'] = segment['end']
+
+    return sanitized
+
 def whisper_transcribe( file, language=None ):
     # shouldn't happen, but it's for safety
     global whisper_model
@@ -1150,7 +1168,7 @@ def whisper_transcribe( file, language=None ):
     segments = whisper_model.extract_text_and_timestamps( res )
     result = {
-        'text': []
+        'text': [],
         'segments': []
     }
 
     for segment in segments:
@@ -1248,6 +1266,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=None
         result = whisper_transcribe(file, language=language)
         results[basename] = result
+        # results[basename] = whisper_sanitize(results[basename])
+
         waveform, sample_rate = torchaudio.load(file)
 
         # resample to the input rate, since it'll get resampled for training anyways
         # this should also "help" increase throughput a bit when filling the dataloaders
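For reference, a rough usage sketch of the new whisper_sanitize helper on a made-up transcription result (not part of the commit; MIN_TRAINING_DURATION is the module's minimum clip length in seconds, and 0.6 below is only an illustrative value):

MIN_TRAINING_DURATION = 0.6  # illustrative stand-in for the module constant

results = {
    'text': ['hello', ' world'],
    'segments': [
        {'text': 'hello', 'start': 0.0, 'end': 1.2},   # 1.2s, long enough to keep
        {'text': ' world', 'start': 1.2, 'end': 1.5},  # 0.3s, collapsed into the previous segment
    ],
}

sanitized = whisper_sanitize(results)
# sanitized['segments'] -> [{'text': 'hello world', 'start': 0.0, 'end': 1.5}]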
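The trailing comments describe resampling each clip to the trainer's input rate at transcription time; a minimal sketch of that step with torchaudio (not the commit's actual code; the path and the 22050 rate are only stand-ins for whatever the training pipeline expects):

import torchaudio
import torchaudio.functional as AF

path = 'voices/example/clip.wav'  # hypothetical input clip
target_rate = 22050               # stand-in for the trainer's expected input rate

waveform, sample_rate = torchaudio.load(path)
if sample_rate != target_rate:
    waveform = AF.resample(waveform, sample_rate, target_rate)
    sample_rate = target_rate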