my sanitizer actually did work, it was just batch sizes leading to problems when transcribing

This commit is contained in:
mrq 2023-03-23 04:41:56 +00:00
parent a6daf289bc
commit 444bcdaf62

View File

@ -1206,10 +1206,13 @@ def whisper_transcribe( file, language=None ):
device = "cuda" if get_device_name() == "cuda" else "cpu" device = "cuda" if get_device_name() == "cuda" else "cpu"
if whisper_vad: if whisper_vad:
"""
if args.whisper_batchsize > 1: if args.whisper_batchsize > 1:
result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe") result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
else: else:
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
"""
result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
else: else:
result = whisper_model.transcribe(file) result = whisper_model.transcribe(file)
@ -1282,19 +1285,32 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
basename = os.path.basename(file) basename = os.path.basename(file)
modified = False
if basename in results and skip_existings: if basename in results and skip_existings:
print(f"Skipping already parsed file: {basename}") print(f"Skipping already parsed file: {basename}")
else: continue
try: try:
result = whisper_transcribe(file, language=language) result = whisper_transcribe(file, language=language)
modified = True
except Exception as e: except Exception as e:
print("Failed to transcribe:", file) print("Failed to transcribe:", file)
continue continue
results[basename] = result
""" results[basename] = result
waveform, sample_rate = torchaudio.load(file)
# resample to the input rate, since it'll get resampled for training anyways
# this should also "help" increase throughput a bit when filling the dataloaders
waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
if waveform.shape[0] == 2:
waveform = waveform[:1]
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
with open(infile, 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t'))
do_gc()
modified = False
for basename in results:
try: try:
sanitized = whisper_sanitize(results[basename]) sanitized = whisper_sanitize(results[basename])
if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']): if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']):
@ -1304,22 +1320,12 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
except Exception as e: except Exception as e:
print("Failed to sanitize:", basename, e) print("Failed to sanitize:", basename, e)
pass pass
"""
waveform, sample_rate = torchaudio.load(file)
# resample to the input rate, since it'll get resampled for training anyways
# this should also "help" increase throughput a bit when filling the dataloaders
waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
if waveform.shape[0] == 2:
waveform = waveform[:1]
torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
if modified: if modified:
os.rename(infile, infile.replace(".json", ".unsanitized.json"))
with open(infile, 'w', encoding="utf-8") as f: with open(infile, 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t')) f.write(json.dumps(results, indent='\t'))
do_gc()
return f"Processed dataset to: {indir}" return f"Processed dataset to: {indir}"
def slice_waveform( waveform, sample_rate, start, end, trim ): def slice_waveform( waveform, sample_rate, start, end, trim ):