(disabled by default until i validate it working) added additional transcription text normalization (something else I'm experimenting with requires it)

This commit is contained in:
mrq 2023-03-13 19:07:23 +00:00
parent 66ac8ba766
commit 32d968a8cd

View File

@ -37,6 +37,8 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
from tortoise.utils.text import split_and_recombine_text from tortoise.utils.text import split_and_recombine_text
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, do_gc from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, do_gc
from whisper.normalizers.english import EnglishTextNormalizer
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth" MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
@ -1193,7 +1195,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
messages.append(f"Sliced segments: {files} => {segments}.") messages.append(f"Sliced segments: {files} => {segments}.")
return "\n".join(messages) return "\n".join(messages)
def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ): def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=False ):
indir = f'./training/{voice}/' indir = f'./training/{voice}/'
infile = f'{indir}/whisper.json' infile = f'{indir}/whisper.json'
messages = [] messages = []
@ -1208,6 +1210,9 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
'validation': [], 'validation': [],
} }
normalizer = EnglishTextNormalizer() if normalize else None
errored = 0
for filename in results: for filename in results:
result = results[filename] result = results[filename]
use_segment = use_segments use_segment = use_segments
@ -1215,7 +1220,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
# check if unsegmented text exceeds 200 characters # check if unsegmented text exceeds 200 characters
if not use_segment: if not use_segment:
if len(result['text']) > 200: if len(result['text']) > 200:
message = f"Text length too long (200 < {len(text)}), using segments: {filename}" message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}"
print(message) print(message)
messages.append(message) messages.append(message)
use_segment = True use_segment = True
@ -1225,6 +1230,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
path = f'{indir}/audio/{filename}' path = f'{indir}/audio/{filename}'
if not os.path.exists(path): if not os.path.exists(path):
messages.append(f"Missing source audio: {filename}") messages.append(f"Missing source audio: {filename}")
errored += 1
continue continue
metadata = torchaudio.info(path) metadata = torchaudio.info(path)
@ -1253,15 +1259,17 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
message = f"Missing source audio: {file}" message = f"Missing source audio: {file}"
print(message) print(message)
messages.append(message) messages.append(message)
errored += 1
continue continue
text = segment['text'].strip() text = segment['text'].strip()
normalized_text = text normalized_text = normalizer(text) if normalize and result['language'] == "en" else text
if len(text) > 200: if len(text) > 200:
message = f"Text length too long (200 < {len(text)}), skipping... {file}" message = f"Text length too long (200 < {len(text)}), skipping... {file}"
print(message) print(message)
messages.append(message) messages.append(message)
errored += 1
continue continue
waveform, sample_rate = torchaudio.load(path) waveform, sample_rate = torchaudio.load(path)
@ -1271,6 +1279,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
message = f"{error}, skipping... {file}" message = f"{error}, skipping... {file}"
print(message) print(message)
messages.append(message) messages.append(message)
errored += 1
continue continue
culled = len(text) < text_length culled = len(text) < text_length
@ -1279,8 +1288,13 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
duration = num_channels * num_frames / sample_rate duration = num_channels * num_frames / sample_rate
culled = duration < audio_length culled = duration < audio_length
# lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}|{normalized_text}') # for when i add in a little treat ;), as it requires normalized text
lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}') if normalize and length(normalized_text) < 200:
line = f'audio/{file}|{text}|{normalized_text}'
else:
line = f'audio/{file}|{text}'
lines['training' if not culled else 'validation'].append(line)
training_joined = "\n".join(lines['training']) training_joined = "\n".join(lines['training'])
validation_joined = "\n".join(lines['validation']) validation_joined = "\n".join(lines['validation'])
@ -1291,7 +1305,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
f.write(validation_joined) f.write(validation_joined)
messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}") messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}, culled: {errored}).\n{training_joined}\n\n{validation_joined}")
return "\n".join(messages) return "\n".join(messages)
def calc_iterations( epochs, lines, batch_size ): def calc_iterations( epochs, lines, batch_size ):