forked from mrq/ai-voice-cloning
my sanitizer actually did work, it was just batch sizes leading to problems when transcribing
parent a6daf289bc
commit 444bcdaf62
src/utils.py (38 changed lines)
@@ -1206,10 +1206,13 @@ def whisper_transcribe( file, language=None ):

     device = "cuda" if get_device_name() == "cuda" else "cpu"
     if whisper_vad:
+        """
         if args.whisper_batchsize > 1:
             result = whisperx.transcribe_with_vad_parallel(whisper_model, file, whisper_vad, batch_size=args.whisper_batchsize, language=language, task="transcribe")
         else:
             result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
+        """
+        result = whisperx.transcribe_with_vad(whisper_model, file, whisper_vad)
     else:
         result = whisper_model.transcribe(file)

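In effect, this hunk sidesteps the batched whisperx path entirely: even when args.whisper_batchsize is greater than 1, transcription now always goes through the plain transcribe_with_vad call. A minimal sketch of the resulting behaviour, assuming a Whisper model and VAD pipeline loaded the way src/utils.py already does it (the transcribe_one wrapper and its parameter names are illustrative, not part of the repo):

    import whisperx  # the whisperx build this project already depends on

    def transcribe_one(model, vad_pipeline, path, use_vad=True):
        # Post-commit behaviour: the transcribe_with_vad_parallel branch is
        # fenced off behind a docstring, so the batch size setting no longer
        # affects transcription results.
        if use_vad and vad_pipeline is not None:
            return whisperx.transcribe_with_vad(model, path, vad_pipeline)
        # Without a VAD pipeline, fall back to plain Whisper, as before.
        return model.transcribe(path)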
@@ -1282,19 +1285,32 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
     for file in enumerate_progress(files, desc="Iterating through voice files", progress=progress):
         basename = os.path.basename(file)

-        modified = False
         if basename in results and skip_existings:
             print(f"Skipping already parsed file: {basename}")
-        else:
+            continue
+
         try:
             result = whisper_transcribe(file, language=language)
-            modified = True
         except Exception as e:
             print("Failed to transcribe:", file)
             continue
-        results[basename] = result

-        """
+        results[basename] = result
+        waveform, sample_rate = torchaudio.load(file)
+        # resample to the input rate, since it'll get resampled for training anyways
+        # this should also "help" increase throughput a bit when filling the dataloaders
+        waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
+        if waveform.shape[0] == 2:
+            waveform = waveform[:1]
+        torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
+
+        with open(infile, 'w', encoding="utf-8") as f:
+            f.write(json.dumps(results, indent='\t'))
+
+        do_gc()
+
+    modified = False
+    for basename in results:
         try:
             sanitized = whisper_sanitize(results[basename])
             if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']):
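This hunk restores the per-file bookkeeping that had been folded into the commented-out block: after each successful transcription, the clip is resampled to the training rate, downmixed to mono, saved under {indir}/audio/, and the accumulated results are re-written to the JSON so progress survives a crash mid-folder. A rough standalone sketch of that per-file step, using torchaudio.functional.resample in place of the repo's own resample() helper; the whisper.json name and the 22050 rate are placeholders for whatever infile and TARGET_SAMPLE_RATE actually are in src/utils.py:

    import json
    import os
    import torchaudio

    TARGET_SAMPLE_RATE = 22050  # placeholder; use whatever src/utils.py defines

    def save_clip_and_results(file, basename, results, indir):
        # Load the source clip, resample it to the training rate, and downmix
        # stereo to mono, mirroring the lines added in this hunk.
        waveform, sample_rate = torchaudio.load(file)
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
        if waveform.shape[0] == 2:
            waveform = waveform[:1]
        os.makedirs(f"{indir}/audio", exist_ok=True)
        torchaudio.save(f"{indir}/audio/{basename}", waveform, TARGET_SAMPLE_RATE,
                        encoding="PCM_S", bits_per_sample=16)

        # Re-dump everything transcribed so far after every file, so an OOM or
        # crash partway through a voice folder does not lose finished work.
        with open(f"{indir}/whisper.json", 'w', encoding="utf-8") as f:
            f.write(json.dumps(results, indent='\t'))

The next hunk continues inside the same sanitize loop started above.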
@@ -1304,22 +1320,12 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
         except Exception as e:
             print("Failed to sanitize:", basename, e)
             pass
-    """
-
-    waveform, sample_rate = torchaudio.load(file)
-    # resample to the input rate, since it'll get resampled for training anyways
-    # this should also "help" increase throughput a bit when filling the dataloaders
-    waveform, sample_rate = resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
-    if waveform.shape[0] == 2:
-        waveform = waveform[:1]
-    torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)

     if modified:
+        os.rename(infile, infile.replace(".json", ".unsanitized.json"))
         with open(infile, 'w', encoding="utf-8") as f:
             f.write(json.dumps(results, indent='\t'))

-    do_gc()
-
     return f"Processed dataset to: {indir}"

 def slice_waveform( waveform, sample_rate, start, end, trim ):
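The net effect of this last hunk: the waveform saving and do_gc() work move into the per-file loop above, and what remains here is the sanitization pass over the full result set, now guarded so the original JSON is kept as an .unsanitized.json backup whenever the sanitizer changed anything. The lines elided between the two hunks presumably swap the sanitized entry into results and flip the modified flag; under that assumption, a hedged sketch of the whole second pass, with sanitize standing in for the repo's whisper_sanitize():

    import json
    import os

    def sanitize_results(results, infile, sanitize):
        # Run the sanitizer over every stored transcription; only rewrite the
        # JSON if it actually dropped or merged segments, keeping the original
        # file as an .unsanitized.json backup.
        modified = False
        for basename in results:
            try:
                sanitized = sanitize(results[basename])
                if len(sanitized['segments']) > 0 and len(sanitized['segments']) != len(results[basename]['segments']):
                    results[basename] = sanitized  # assumed from the elided lines
                    modified = True
            except Exception as e:
                print("Failed to sanitize:", basename, e)

        if modified:
            os.rename(infile, infile.replace(".json", ".unsanitized.json"))
            with open(infile, 'w', encoding="utf-8") as f:
                f.write(json.dumps(results, indent='\t'))
        return results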