diff --git a/src/utils.py b/src/utils.py index 3da36af..a8c175a 100755 --- a/src/utils.py +++ b/src/utils.py @@ -37,6 +37,8 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_ from tortoise.utils.text import split_and_recombine_text from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, do_gc +from whisper.normalizers.english import EnglishTextNormalizer + MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] @@ -1193,7 +1195,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul messages.append(f"Sliced segments: {files} => {segments}.") return "\n".join(messages) -def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ): +def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=False ): indir = f'./training/{voice}/' infile = f'{indir}/whisper.json' messages = [] @@ -1208,6 +1210,9 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T 'validation': [], } + normalizer = EnglishTextNormalizer() if normalize else None + + errored = 0 for filename in results: result = results[filename] use_segment = use_segments @@ -1215,7 +1220,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T # check if unsegmented text exceeds 200 characters if not use_segment: if len(result['text']) > 200: - message = f"Text length too long (200 < {len(text)}), using segments: {filename}" + message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}" print(message) messages.append(message) use_segment = True @@ -1225,6 +1230,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T path = f'{indir}/audio/{filename}' if not os.path.exists(path): messages.append(f"Missing source audio: {filename}") + errored += 1 continue metadata = torchaudio.info(path) @@ -1253,15 +1259,17 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T message = f"Missing source audio: {file}" print(message) messages.append(message) + errored += 1 continue text = segment['text'].strip() - normalized_text = text + normalized_text = normalizer(text) if normalize and result['language'] == "en" else text if len(text) > 200: message = f"Text length too long (200 < {len(text)}), skipping... {file}" print(message) messages.append(message) + errored += 1 continue waveform, sample_rate = torchaudio.load(path) @@ -1271,6 +1279,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T message = f"{error}, skipping... {file}" print(message) messages.append(message) + errored += 1 continue culled = len(text) < text_length @@ -1279,8 +1288,13 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T duration = num_channels * num_frames / sample_rate culled = duration < audio_length - # lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}|{normalized_text}') - lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}') + # for when i add in a little treat ;), as it requires normalized text + if normalize and length(normalized_text) < 200: + line = f'audio/{file}|{text}|{normalized_text}' + else: + line = f'audio/{file}|{text}' + + lines['training' if not culled else 'validation'].append(line) training_joined = "\n".join(lines['training']) validation_joined = "\n".join(lines['validation']) @@ -1291,7 +1305,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: f.write(validation_joined) - messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}") + messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}, culled: {errored}).\n{training_joined}\n\n{validation_joined}") return "\n".join(messages) def calc_iterations( epochs, lines, batch_size ):