From 239c98485040aa7103b0b354d6aa23eaa5ae68f8 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 12 Mar 2023 23:39:00 +0000
Subject: [PATCH] move audio validation from slicing to text-file creation;
 treat audio longer than 11 seconds as invalid; treat text longer than 200
 characters as invalid

---
 src/utils.py | 88 ++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 58 insertions(+), 30 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 0b2aa77..809d3ac 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -1065,11 +1065,18 @@ def whisper_transcribe( file, language=None ):
 
 def validate_waveform( waveform, sample_rate ):
 	if not torch.any(waveform < 0):
-		return False
+		return "Waveform is empty"
 
-	if waveform.shape[-1] < (.6 * sample_rate):
-		return False
-	return True
+	num_channels, num_frames = waveform.shape
+	duration = num_frames / sample_rate # frames are per channel; channel count does not lengthen a clip
+
+	if duration < 0.6:
+		return "Duration too short ({:.3f} < 0.6s)".format(duration)
+
+	if duration > 11:
+		return "Duration too long (11s < {:.3f})".format(duration)
+
+	return None
 
 def transcribe_dataset( voice, language=None, skip_existings=False, progress=None ):
 	unload_tts()
@@ -1100,8 +1107,8 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 			results[basename] = whisper_transcribe(file, language=language)
 
 		# lazy copy
-		waveform, sampling_rate = torchaudio.load(file)
-		torchaudio.save(f"{indir}/audio/{basename}", waveform, sampling_rate)
+		waveform, sample_rate = torchaudio.load(file)
+		torchaudio.save(f"{indir}/audio/{basename}", waveform, sample_rate)
 
 		with open(infile, 'w', encoding="utf-8") as f:
 			f.write(json.dumps(results, indent='\t'))
@@ -1115,6 +1122,7 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non
 def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
+	messages = []
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
 
@@ -1124,15 +1132,21 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 
 	files = 0
 	segments = 0
 	for filename in results:
-		files += 1
+		path = f'./voices/{voice}/{filename}'
+		if not os.path.exists(path):
+			path = f'./training/{voice}/{filename}'
+		if not os.path.exists(path):
+			messages.append(f"Missing source audio: {filename}")
+			continue
+
+		files += 1
 		result = results[filename]
-		waveform, sampling_rate = torchaudio.load(f'./voices/{voice}/{filename}')
+		waveform, sample_rate = torchaudio.load(path)
 
-		for segment in result['segments']: # enumerate_progress(result['segments'], desc="Segmenting voice file", progress=progress):
-			segments +=1
-			start = int((segment['start'] + start_offset) * sampling_rate)
-			end = int((segment['end'] + end_offset) * sampling_rate)
+		for segment in result['segments']:
+			start = int((segment['start'] + start_offset) * sample_rate)
+			end = int((segment['end'] + end_offset) * sample_rate)
 
 			if start < 0:
 				start = 0
@@ -1142,20 +1156,19 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ):
 			sliced = waveform[:, start:end]
 			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav")
 
-			if not validate_waveform( sliced, sampling_rate ):
-				print(f"Invalid waveform segment ({segment['start']}:{segment['end']}): {file}, skipping...")
-				continue
-
 			if trim_silence:
-				sliced = torchaudio.functional.vad( sliced, sampling_rate )
+				sliced = torchaudio.functional.vad( sliced, sample_rate )
 
-			torchaudio.save(f"{indir}/audio/{file}", sliced, sampling_rate)
+			segments += 1
+			torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate)
 
-	return f"Sliced segments: {files} => {segments}."
+	messages.append(f"Sliced segments: {files} => {segments}.")
+	return "\n".join(messages)
 
 def prepare_dataset( voice, use_segments, text_length, audio_length ):
 	indir = f'./training/{voice}/'
 	infile = f'{indir}/whisper.json'
+	messages = []
 	if not os.path.exists(infile):
 		raise Exception(f"Missing dataset: {infile}")
 
@@ -1171,16 +1184,28 @@ def prepare_dataset( voice, use_segments, text_length, audio_length ):
 		result = results[filename]
 		segments = result['segments'] if use_segments else [{'text': result['text']}]
 		for segment in segments:
-			text = segment['text'].strip()
 			file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename
 			path = f'{indir}/audio/{file}'
 			if not os.path.exists(path):
+				messages.append(f"Missing source audio: {file}")
+				continue
+
+			text = segment['text'].strip()
+			if len(text) > 200:
+				messages.append(f"[{file}] Text length too long (200 < {len(text)}), skipping...")
+				continue
+
+			waveform, sample_rate = torchaudio.load(path)
+			num_channels, num_frames = waveform.shape
+			duration = num_frames / sample_rate # frames are per channel
+
+			error = validate_waveform( waveform, sample_rate )
+			if error:
+				messages.append(f"[{file}]: {error}, skipping...")
 				continue
 
 			culled = len(text) < text_length
 			if not culled and audio_length > 0:
-				metadata = torchaudio.info(path)
-				duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
 				culled = duration < audio_length
 
 			lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
@@ -1194,8 +1219,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length ):
 	with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
 		f.write(validation_joined)
 
-	msg = f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}"
-	return msg
+	messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}")
+	return "\n".join(messages)
 
 def calc_iterations( epochs, lines, batch_size ):
 	iterations = int(epochs * lines / float(batch_size))
@@ -1213,6 +1238,9 @@ def optimize_training_settings( **kwargs ):
 	with open(dataset_path, 'r', encoding="utf-8") as f:
 		lines = len(f.readlines())
 
+	if lines == 0:
+		raise Exception("Empty dataset.")
+
 	if settings['batch_size'] > lines:
 		settings['batch_size'] = lines
 		messages.append(f"Batch size is larger than your dataset, clamping batch size to: {settings['batch_size']}")
@@ -1471,17 +1499,17 @@ def import_voices(files, saveAs=None, progress=None):
 			path = f"{outdir}/{os.path.basename(filename)}"
 			print(f"Importing voice to {path}")
 
-			waveform, sampling_rate = torchaudio.load(filename)
+			waveform, sample_rate = torchaudio.load(filename)
 
 			if args.voice_fixer:
 				if not voicefixer:
 					load_voicefixer()
 
 				# resample to best bandwidth since voicefixer will do it anyways through librosa
-				if sampling_rate != 44100:
+				if sample_rate != 44100:
 					print(f"Resampling imported voice sample: {path}")
 					resampler = torchaudio.transforms.Resample(
-						sampling_rate,
+						sample_rate,
 						44100,
 						lowpass_filter_width=16,
 						rolloff=0.85,
@@ -1489,9 +1517,9 @@ def import_voices(files, saveAs=None, progress=None):
 						beta=8.555504641634386,
 					)
 					waveform = resampler(waveform)
-					sampling_rate = 44100
+					sample_rate = 44100
 
-				torchaudio.save(path, waveform, sampling_rate)
+				torchaudio.save(path, waveform, sample_rate)
print(f"Running 'voicefixer' on voice sample: {path}") voicefixer.restore( @@ -1501,7 +1528,7 @@ def import_voices(files, saveAs=None, progress=None): #mode=mode, ) else: - torchaudio.save(path, waveform, sampling_rate) + torchaudio.save(path, waveform, sample_rate) print(f"Imported voice to {path}")