From 444bcdaf622acd895b7615c9344d9940061bd0f0 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 23 Mar 2023 04:41:56 +0000 Subject: [PATCH] my sanitizer actually did work, it was just batch sizes leading to problems when transcribing --- src/utils.py | 54 +++++++++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/src/utils.py b/src/utils.py index 668d696..e3277bf 100755 --- a/src/utils.py +++ b/src/utils.py @@ -1206,10 +1206,13 @@ def whisper_transcribe( file, language=None ): device = "cuda" if get_device_name() == "cuda" else "cpu" if whisper_vad: + """ if args.whisper_batchsize > 1: result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe") else: result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) + """ + result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) else: result = whisper_model.transcribe(file) @@ -1282,19 +1285,32 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): basename = os.path.basename(file) - modified = False if basename in results and skip_existings: print(f"Skipping already parsed file: {basename}") - else: - try: - result = whisper_transcribe(file, language=language) - modified = True - except Exception as e: - print("Failed to transcribe:", file) - continue - results[basename] = result + continue - """ + try: + result = whisper_transcribe(file, language=language) + except Exception as e: + print("Failed to transcribe:", file) + continue + + results[basename] = result + waveform, sample_rate = torchaudio.load(file) + # resample to the input rate, since it'll get resampled for training anyways + # this should also "help" increase throughput a bit when filling the dataloaders + waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE) + if waveform.shape[0] == 2: + waveform = waveform[:1] + torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16) + + with open(infile, 'w', encoding="utf-8") as f: + f.write(json.dumps(results, indent='\t')) + + do_gc() + + modified = False + for basename in results: try: sanitized = whisper_sanitize(results[basename]) if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']): @@ -1304,21 +1320,11 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non except Exception as e: print("Failed to sanitize:", basename, e) pass - """ - waveform, sample_rate = torchaudio.load(file) - # resample to the input rate, since it'll get resampled for training anyways - # this should also "help" increase throughput a bit when filling the dataloaders - waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE) - if waveform.shape[0] == 2: - waveform = waveform[:1] - torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16) - - if modified: - with open(infile, 'w', encoding="utf-8") as f: - f.write(json.dumps(results, indent='\t')) - - do_gc() + if modified: + os.rename(infile, infile.replace(".json", ".unsanitized.json")) + with open(infile, 'w', encoding="utf-8") as f: + f.write(json.dumps(results, indent='\t')) return f"Processed dataset to: {indir}"