diff --git a/src/utils.py b/src/utils.py index 668d696..e3277bf 100755 --- a/src/utils.py +++ b/src/utils.py @@ -1206,10 +1206,13 @@ def whisper_transcribe( file, language=None ): device = "cuda" if get_device_name() == "cuda" else "cpu" if whisper_vad: + """ if args.whisper_batchsize > 1: result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe") else: result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) + """ + result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad) else: result = whisper_model.transcribe(file) @@ -1282,30 +1285,17 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress): basename = os.path.basename(file) - modified = False if basename in results and skip_existings: print(f"Skipping already parsed file: {basename}") - else: - try: - result = whisper_transcribe(file, language=language) - modified = True - except Exception as e: - print("Failed to transcribe:", file) - continue - results[basename] = result + continue - """ try: - sanitized = whisper_sanitize(results[basename]) - if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']): - results[basename] = sanitized - modified = True - print("Segments sanizited: ", basename) + result = whisper_transcribe(file, language=language) except Exception as e: - print("Failed to sanitize:", basename, e) - pass - """ + print("Failed to transcribe:", file) + continue + results[basename] = result waveform, sample_rate = torchaudio.load(file) # resample to the input rate, since it'll get resampled for training anyways # this should also "help" increase throughput a bit when filling the dataloaders @@ -1314,12 +1304,28 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non waveform = waveform[:1] torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16) - if modified: - with open(infile, 'w', encoding="utf-8") as f: - f.write(json.dumps(results, indent='\t')) + with open(infile, 'w', encoding="utf-8") as f: + f.write(json.dumps(results, indent='\t')) do_gc() + modified = False + for basename in results: + try: + sanitized = whisper_sanitize(results[basename]) + if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']): + results[basename] = sanitized + modified = True + print("Segments sanizited: ", basename) + except Exception as e: + print("Failed to sanitize:", basename, e) + pass + + if modified: + os.rename(infile, infile.replace(".json", ".unsanitized.json")) + with open(infile, 'w', encoding="utf-8") as f: + f.write(json.dumps(results, indent='\t')) + return f"Processed dataset to: {indir}" def slice_waveform( waveform, sample_rate, start, end, trim ):