ugh
This commit is contained in:
parent
13605f980c
commit
aa5bdafb06
22
src/utils.py
22
src/utils.py
|
@ -1131,6 +1131,24 @@ def convert_to_halfp():
|
||||||
torch.save(model, outfile)
|
torch.save(model, outfile)
|
||||||
print(f'Converted model to half precision: {outfile}')
|
print(f'Converted model to half precision: {outfile}')
|
||||||
|
|
||||||
|
|
||||||
|
# collapses short segments into the previous segment
|
||||||
|
def whisper_sanitize( results ):
|
||||||
|
sanitized = results
|
||||||
|
sanitized['segments'] = []
|
||||||
|
|
||||||
|
for segment in results['segments']:
|
||||||
|
length = segment['end'] - segment['start']
|
||||||
|
if length >= MIN_TRAINING_DURATION or len(sanitized['segments']) == 0:
|
||||||
|
sanitized['segments'].append(segment)
|
||||||
|
continue
|
||||||
|
|
||||||
|
last_segment = sanitized['segments'][-1]
|
||||||
|
last_segment['text'] += segment['text']
|
||||||
|
last_segment['end'] = segment['end']
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
def whisper_transcribe( file, language=None ):
|
def whisper_transcribe( file, language=None ):
|
||||||
# shouldn't happen, but it's for safety
|
# shouldn't happen, but it's for safety
|
||||||
global whisper_model
|
global whisper_model
|
||||||
|
@ -1150,7 +1168,7 @@ def whisper_transcribe( file, language=None ):
|
||||||
segments = whisper_model.extract_text_and_timestamps( res )
|
segments = whisper_model.extract_text_and_timestamps( res )
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
'text': []
|
'text': [],
|
||||||
'segments': []
|
'segments': []
|
||||||
}
|
}
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
|
@ -1248,6 +1266,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
|
||||||
result = whisper_transcribe(file, language=language)
|
result = whisper_transcribe(file, language=language)
|
||||||
results[basename] = result
|
results[basename] = result
|
||||||
|
|
||||||
|
# results[basename] = whisper_sanitize(results[basename])
|
||||||
|
|
||||||
waveform, sample_rate = torchaudio.load(file)
|
waveform, sample_rate = torchaudio.load(file)
|
||||||
# resample to the input rate, since it'll get resampled for training anyways
|
# resample to the input rate, since it'll get resampled for training anyways
|
||||||
# this should also "help" increase throughput a bit when filling the dataloaders
|
# this should also "help" increase throughput a bit when filling the dataloaders
|
||||||
|
|
Loading…
Reference in New Issue
Block a user