@ -37,6 +37,8 @@ from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_
from tortoise . utils . text import split_and_recombine_text
from tortoise . utils . device import get_device_name , set_device_name , get_device_count , get_device_vram , do_gc
from whisper . normalizers . english import EnglishTextNormalizer
MODELS [ ' dvae.pth ' ] = " https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth "
WHISPER_MODELS = [ " tiny " , " base " , " small " , " medium " , " large " ]
@ -1193,7 +1195,7 @@ def slice_dataset( voice, trim_silence=True, start_offset=0, end_offset=0, resul
messages . append ( f " Sliced segments: { files } => { segments } . " )
return " \n " . join ( messages )
def prepare_dataset ( voice , use_segments , text_length , audio_length , normalize = Tru e ) :
def prepare_dataset ( voice , use_segments , text_length , audio_length , normalize = Fals e ) :
indir = f ' ./training/ { voice } / '
infile = f ' { indir } /whisper.json '
messages = [ ]
@ -1208,6 +1210,9 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
' validation ' : [ ] ,
}
normalizer = EnglishTextNormalizer ( ) if normalize else None
errored = 0
for filename in results :
result = results [ filename ]
use_segment = use_segments
@ -1215,7 +1220,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
# check if unsegmented text exceeds 200 characters
if not use_segment :
if len ( result [ ' text ' ] ) > 200 :
message = f " Text length too long (200 < { len ( text) } ), using segments: { filename } "
message = f " Text length too long (200 < { len ( result[ ' text' ] ) } ), using segments: { filename } "
print ( message )
messages . append ( message )
use_segment = True
@ -1225,6 +1230,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
path = f ' { indir } /audio/ { filename } '
if not os . path . exists ( path ) :
messages . append ( f " Missing source audio: { filename } " )
errored + = 1
continue
metadata = torchaudio . info ( path )
@ -1253,15 +1259,17 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
message = f " Missing source audio: { file } "
print ( message )
messages . append ( message )
errored + = 1
continue
text = segment [ ' text ' ] . strip ( )
normalized_text = text
normalized_text = normalizer( text ) if normalize and result [ ' language ' ] == " en " else text
if len ( text ) > 200 :
message = f " Text length too long (200 < { len ( text ) } ), skipping... { file } "
print ( message )
messages . append ( message )
errored + = 1
continue
waveform , sample_rate = torchaudio . load ( path )
@ -1271,6 +1279,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
message = f " { error } , skipping... { file } "
print ( message )
messages . append ( message )
errored + = 1
continue
culled = len ( text ) < text_length
@ -1279,8 +1288,13 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
duration = num_channels * num_frames / sample_rate
culled = duration < audio_length
# lines['training' if not culled else 'validation'].append(f'audio/{file}|{text}|{normalized_text}')
lines [ ' training ' if not culled else ' validation ' ] . append ( f ' audio/ { file } | { text } ' )
# for when i add in a little treat ;), as it requires normalized text
if normalize and length ( normalized_text ) < 200 :
line = f ' audio/ { file } | { text } | { normalized_text } '
else :
line = f ' audio/ { file } | { text } '
lines [ ' training ' if not culled else ' validation ' ] . append ( line )
training_joined = " \n " . join ( lines [ ' training ' ] )
validation_joined = " \n " . join ( lines [ ' validation ' ] )
@ -1291,7 +1305,7 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
with open ( f ' { indir } /validation.txt ' , ' w ' , encoding = " utf-8 " ) as f :
f . write ( validation_joined )
messages . append ( f " Prepared { len ( lines [ ' training ' ] ) } lines (validation: { len ( lines [ ' validation ' ] ) } ).\n { training_joined } \n \n { validation_joined } " )
messages . append ( f " Prepared { len ( lines [ ' training ' ] ) } lines (validation: { len ( lines [ ' validation ' ] ) } , culled: { errored } ).\n { training_joined } \n \n { validation_joined } " )
return " \n " . join ( messages )
def calc_iterations ( epochs , lines , batch_size ) :