forked from mrq/ai-voice-cloning
(disabled by default until i validate it working) added additional transcription text normalization (something else I'm experimenting with requires it)
This commit is contained in:
parent
66ac8ba766
commit
32d968a8cd
26
src/utils.py
26
src/utils.py
|
@ -37,6 +37,8 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
|
||||||
from tortoise.utils.text import split_and_recombine_text
|
from tortoise.utils.text import split_and_recombine_text
|
||||||
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, do_gc
|
from tortoise.utils.device import get_device_name, set_device_name, get_device_count, get_device_vram, do_gc
|
||||||
|
|
||||||
|
from whisper.normalizers.english import EnglishTextNormalizer
|
||||||
|
|
||||||
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth"
|
||||||
|
|
||||||
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
||||||
|
@ -1193,7 +1195,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
|
||||||
messages.append(f"Sliced segments: {files} => {segments}.")
|
messages.append(f"Sliced segments: {files} => {segments}.")
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=True ):
|
def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=False ):
|
||||||
indir = f'./training/{voice}/'
|
indir = f'./training/{voice}/'
|
||||||
infile = f'{indir}/whisper.json'
|
infile = f'{indir}/whisper.json'
|
||||||
messages = []
|
messages = []
|
||||||
|
@ -1208,6 +1210,9 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
'validation': [],
|
'validation': [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
normalizer = EnglishTextNormalizer() if normalize else None
|
||||||
|
|
||||||
|
errored = 0
|
||||||
for filename in results:
|
for filename in results:
|
||||||
result = results[filename]
|
result = results[filename]
|
||||||
use_segment = use_segments
|
use_segment = use_segments
|
||||||
|
@ -1215,7 +1220,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
# check if unsegmented text exceeds 200 characters
|
# check if unsegmented text exceeds 200 characters
|
||||||
if not use_segment:
|
if not use_segment:
|
||||||
if len(result['text']) > 200:
|
if len(result['text']) > 200:
|
||||||
message = f"Text length too long (200 < {len(text)}), using segments: {filename}"
|
message = f"Text length too long (200 < {len(result['text'])}), using segments: {filename}"
|
||||||
print(message)
|
print(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
use_segment = True
|
use_segment = True
|
||||||
|
@ -1225,6 +1230,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
path = f'{indir}/audio/{filename}'
|
path = f'{indir}/audio/{filename}'
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
messages.append(f"Missing source audio: {filename}")
|
messages.append(f"Missing source audio: {filename}")
|
||||||
|
errored += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
metadata = torchaudio.info(path)
|
metadata = torchaudio.info(path)
|
||||||
|
@ -1253,15 +1259,17 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
message = f"Missing source audio: {file}"
|
message = f"Missing source audio: {file}"
|
||||||
print(message)
|
print(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
errored += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
text = segment['text'].strip()
|
text = segment['text'].strip()
|
||||||
normalized_text = text
|
normalized_text = normalizer(text) if normalize and result['language'] == "en" else text
|
||||||
|
|
||||||
if len(text) > 200:
|
if len(text) > 200:
|
||||||
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
||||||
print(message)
|
print(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
errored += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
waveform, sample_rate = torchaudio.load(path)
|
waveform, sample_rate = torchaudio.load(path)
|
||||||
|
@ -1271,6 +1279,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
message = f"{error}, skipping... {file}"
|
message = f"{error}, skipping... {file}"
|
||||||
print(message)
|
print(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
errored += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
culled = len(text) < text_length
|
culled = len(text) < text_length
|
||||||
|
@ -1279,8 +1288,13 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
duration = num_channels * num_frames / sample_rate
|
duration = num_channels * num_frames / sample_rate
|
||||||
culled = duration < audio_length
|
culled = duration < audio_length
|
||||||
|
|
||||||
# lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}|{normalized_text}')
|
# for when i add in a little treat ;), as it requires normalized text
|
||||||
lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}')
|
if normalize and length(normalized_text) < 200:
|
||||||
|
line = f'audio/{file}|{text}|{normalized_text}'
|
||||||
|
else:
|
||||||
|
line = f'audio/{file}|{text}'
|
||||||
|
|
||||||
|
lines['training' if not culled else 'validation'].append(line)
|
||||||
|
|
||||||
training_joined = "\n".join(lines['training'])
|
training_joined = "\n".join(lines['training'])
|
||||||
validation_joined = "\n".join(lines['validation'])
|
validation_joined = "\n".join(lines['validation'])
|
||||||
|
@ -1291,7 +1305,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||||
f.write(validation_joined)
|
f.write(validation_joined)
|
||||||
|
|
||||||
messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}).\n{training_joined}\n\n{validation_joined}")
|
messages.append(f"Prepared {len(lines['training'])} lines (validation: {len(lines['validation'])}, culled: {errored}).\n{training_joined}\n\n{validation_joined}")
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
def calc_iterations( epochs, lines, batch_size ):
|
def calc_iterations( epochs, lines, batch_size ):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user