(disabled by default until i validate it working) added additional transcription text normalization (something else I'm experimenting with requires it)
This commit is contained in:
parent
66ac8ba766
commit
32d968a8cd
26
src/utils.py
26
src/utils.py
|
@ -37,6 +37,8 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
|
|||
from tortoise.utils.text import split_and_recombine_text
|
||||
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, do_gc
|
||||
|
||||
from whisper.normalizers.english import EnglishTextNormalizer
|
||||
|
||||
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||
|
||||
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
||||
|
@ -1193,7 +1195,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
|||
messages.append(f"Sliced segments: {files} => {segments}.")
|
||||
return "\n".join(messages)
|
||||
|
||||
def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
|
||||
def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=False ):
|
||||
indir = f'./training/{voice}/'
|
||||
infile = f'{indir}/whisper.json'
|
||||
messages = []
|
||||
|
@ -1208,6 +1210,9 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
'validation': [],
|
||||
}
|
||||
|
||||
normalizer = EnglishTextNormalizer() if normalize else None
|
||||
|
||||
errored = 0
|
||||
for filename in results:
|
||||
result = results[filename]
|
||||
use_segment = use_segments
|
||||
|
@ -1215,7 +1220,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
# check if unsegmented text exceeds 200 characters
|
||||
if not use_segment:
|
||||
if len(result['text']) > 200:
|
||||
message = f"Text length too long (200 < {len(text)}), using segments: {filename}"
|
||||
message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}"
|
||||
print(message)
|
||||
messages.append(message)
|
||||
use_segment = True
|
||||
|
@ -1225,6 +1230,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
path = f'{indir}/audio/{filename}'
|
||||
if not os.path.exists(path):
|
||||
messages.append(f"Missing source audio: {filename}")
|
||||
errored += 1
|
||||
continue
|
||||
|
||||
metadata = torchaudio.info(path)
|
||||
|
@ -1253,15 +1259,17 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
message = f"Missing source audio: {file}"
|
||||
print(message)
|
||||
messages.append(message)
|
||||
errored += 1
|
||||
continue
|
||||
|
||||
text = segment['text'].strip()
|
||||
normalized_text = text
|
||||
normalized_text = normalizer(text) if normalize and result['language'] == "en" else text
|
||||
|
||||
if len(text) > 200:
|
||||
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
||||
print(message)
|
||||
messages.append(message)
|
||||
errored += 1
|
||||
continue
|
||||
|
||||
waveform, sample_rate = torchaudio.load(path)
|
||||
|
@ -1271,6 +1279,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
message = f"{error}, skipping... {file}"
|
||||
print(message)
|
||||
messages.append(message)
|
||||
errored += 1
|
||||
continue
|
||||
|
||||
culled = len(text) < text_length
|
||||
|
@ -1279,8 +1288,13 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
duration = num_channels * num_frames / sample_rate
|
||||
culled = duration < audio_length
|
||||
|
||||
# lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}|{normalized_text}')
|
||||
lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
|
||||
# for when i add in a little treat ;), as it requires normalized text
|
||||
if normalize and length(normalized_text) < 200:
|
||||
line = f'audio/{file}|{text}|{normalized_text}'
|
||||
else:
|
||||
line = f'audio/{file}|{text}'
|
||||
|
||||
lines['training' if not culled else 'validation'].append(line)
|
||||
|
||||
training_joined = "\n".join(lines['training'])
|
||||
validation_joined = "\n".join(lines['validation'])
|
||||
|
@ -1291,7 +1305,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||
f.write(validation_joined)
|
||||
|
||||
messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}")
|
||||
messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}, culled: {errored}).\n{training_joined}\n\n{validation_joined}")
|
||||
return "\n".join(messages)
|
||||
|
||||
def calc_iterations( epochs, lines, batch_size ):
|
||||
|
|
Loading…
Reference in New Issue
Block a user