From ee1b048d07551b68b946b764f70f2a3daefece22 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 13 Mar 2023 04:26:00 +0000 Subject: [PATCH] when creating the train/validatio datasets, use segments if the main audio's duration is too long, and slice to make the segments if they don't exist --- src/utils.py | 123 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 91 insertions(+), 32 deletions(-) diff --git a/src/utils.py b/src/utils.py index 45f43d1..45f8e9f 100755 --- a/src/utils.py +++ b/src/utils.py @@ -51,6 +51,9 @@ LEARNING_RATE_SCHEDULE = [ 9, 18, 25, 33 ] RESAMPLERS = {} +MIN_TRAINING_DURATION = 0.6 +MAX_TRAINING_DURATION = 11.6097505669 + args = None tts = None tts_loading = False @@ -62,6 +65,9 @@ training_state = None current_voice = None def resample( waveform, input_rate, output_rate=44100 ): + # mono-ize + waveform = torch.mean(waveform, dim=0, keepdim=True) + if input_rate == output_rate: return waveform, output_rate @@ -1066,18 +1072,19 @@ def whisper_transcribe( file, language=None ): result['segments'].append(reparsed) return result -def validate_waveform( waveform, sample_rate ): +def validate_waveform( waveform, sample_rate, min_only=False ): if not torch.any(waveform < 0): return "Waveform is empty" num_channels, num_frames = waveform.shape duration = num_channels * num_frames / sample_rate - if duration < 0.6: - return "Duration too short ({:.3f} < 0.6s)".format(duration) + if duration < MIN_TRAINING_DURATION: + return "Duration too short ({:.3f}s < {:.3f}s)".format(duration, MIN_TRAINING_DURATION) - if duration > 11: - return "Duration too long (11s < {:.3f})".format(duration) + if not min_only: + if duration > MAX_TRAINING_DURATION: + return "Duration too long ({:.3f}s < {:.3f}s)".format(MAX_TRAINING_DURATION, duration) return @@ -1122,7 +1129,24 @@ def transcribe_dataset( voice, language=None, skip_existings=False, progress=Non return f"Processed dataset to: {indir}" -def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ): +def slice_waveform( waveform, sample_rate, start, end, trim ): + start = int(start * sample_rate) + end = int(end * sample_rate) + + if start < 0: + start = 0 + if end >= waveform.shape[-1]: + end = waveform.shape[-1] - 1 + + sliced = waveform[:, start:end] + + error = validate_waveform( sliced, sample_rate, min_only=True ) + if trim and not error: + sliced = torchaudio.functional.vad( sliced, sample_rate ) + + return sliced, error + +def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, results=None ): indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' messages = [] @@ -1130,7 +1154,8 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ): if not os.path.exists(infile): raise Exception(f"Missing dataset: {infile}") - results = json.load(open(infile, 'r', encoding="utf-8")) + if results is None: + results = json.load(open(infile, 'r', encoding="utf-8")) files = 0 segments = 0 @@ -1140,37 +1165,35 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0 ): path = f'./training/{voice}/{filename}' if not os.path.exists(path): - messages.append(f"Missing source audio: {filename}") + message = f"Missing source audio: {filename}" + print(message) + messages.append(message) continue files += 1 result = results[filename] waveform, sample_rate = torchaudio.load(path) + num_channels, num_frames = waveform.shape + duration = num_channels * num_frames / sample_rate for segment in result['segments']: - start = int((segment['start'] + start_offset) * sample_rate) - end = int((segment['end'] + end_offset) * sample_rate) - - if start < 0: - start = 0 - if end >= waveform.shape[-1]: - end = waveform.shape[-1] - 1 - - sliced = waveform[:, start:end] file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") - - if trim_silence: - sliced = torchaudio.functional.vad( sliced, sample_rate ) - sliced, sample_rate = resample( sliced, sample_rate, 22050 ) - torchaudio.save(f"{indir}/audio/{file}", sliced, sample_rate) + sliced, error = slice_waveform( waveform, sample_rate, segment['start'] + start_offset, segment['end'] + end_offset, trim_silence ) + if error: + message = f"{error}, skipping... {file}" + print(message) + messages.append(message) + continue + sliced, _ = resample( sliced, sample_rate, 22050 ) + torchaudio.save(f"{indir}/audio/{file}", sliced, 22050) segments +=1 messages.append(f"Sliced segments: {files} => {segments}.") return "\n".join(messages) -def prepare_dataset( voice, use_segments, text_length, audio_length ): +def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ): indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' messages = [] @@ -1187,31 +1210,67 @@ def prepare_dataset( voice, use_segments, text_length, audio_length ): for filename in results: result = results[filename] - segments = result['segments'] if use_segments else [{'text': result['text']}] - for segment in segments: - file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segments else filename - path = f'{indir}/audio/{file}' + use_segment = use_segments + + # check if unsegmented audio exceeds 11.6s + if not use_segment: + path = f'{indir}/audio/{filename}' if not os.path.exists(path): - messages.append(f"Missing source audio: {file}") + messages.append(f"Missing source audio: {filename}") + continue + + metadata = torchaudio.info(path) + duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate + if duration >= MAX_TRAINING_DURATION: + message = f"Audio too large, using segments: {filename}" + print(message) + messages.append(message) + use_segment = True + + segments = result['segments'] if use_segment else [{'text': result['text']}] + + for segment in segments: + file = filename.replace(".wav", f"_{pad(segment['id'], 4)}.wav") if use_segment else filename + path = f'{indir}/audio/{file}' + # segment when needed + if not os.path.exists(path): + tmp_results = {} + tmp_results[filename] = result + print(f"Audio not segmented, segmenting: {filename}") + message = slice_dataset( voice, results=tmp_results ) + print(message) + messages = messages + message.split("\n") + + if not os.path.exists(path): + message = f"Missing source audio: {file}" + print(message) + messages.append(message) continue text = segment['text'].strip() + normalized_text = text + if len(text) > 200: - messages.append(f"[{file}] Text length too long (200 < {len(text)}), skipping...") + message = f"Text length too long (200 < {len(text)}), skipping... {file}" + print(message) + messages.append(message) waveform, sample_rate = torchaudio.load(path) - num_channels, num_frames = waveform.shape - duration = num_channels * num_frames / sample_rate error = validate_waveform( waveform, sample_rate ) if error: - messages.append(f"[{file}]: {error}, skipping...") + message = f"{error}, skipping... {file}" + print(message) + messages.append(message) continue culled = len(text) < text_length if not culled and audio_length > 0: + num_channels, num_frames = waveform.shape + duration = num_channels * num_frames / sample_rate culled = duration < audio_length + # lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}|{normalized_text}') lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}') training_joined = "\n".join(lines['training'])