@ -10,6 +10,7 @@ if 'TRANSFORMERS_CACHE' not in os.environ:
import argparse
import time
import math
import json
import base64
import re
@ -42,11 +43,6 @@ from whisper.normalizers.english import EnglishTextNormalizer
from whisper . normalizers . basic import BasicTextNormalizer
from whisper . tokenizer import LANGUAGES
try :
from phonemizer import phonemize as phonemizer
except Exception as e :
pass
MODELS [ ' dvae.pth ' ] = " https://huggingface.co/jbetker/tortoise-tts-v2/resolve/3704aea61678e7e468a06d8eea121dba368a798e/.models/dvae.pth "
WHISPER_MODELS = [ " tiny " , " base " , " small " , " medium " , " large " ]
@ -340,6 +336,9 @@ def generate(**kwargs):
INFERENCING = True
for line , cut_text in enumerate ( texts ) :
if should_phonemize ( ) :
cut_text = phonemizer ( cut_text )
if parameters [ ' emotion ' ] == " Custom " :
if parameters [ ' prompt ' ] and parameters [ ' prompt ' ] . strip ( ) != " " :
cut_text = f " [ { parameters [ ' prompt ' ] } ,] { cut_text } "
@ -636,46 +635,31 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
# superfluous, but it cleans up some things
class TrainingState ( ) :
def __init__ ( self , config_path , keep_x_past_checkpoints = 0 , start = True ) :
# parse config to get its iteration
self . killed = False
self . training_dir = os . path . dirname ( config_path )
with open ( config_path , ' r ' ) as file :
self . config = yaml . safe_load ( file )
self . yaml_ config = yaml . safe_load ( file )
self . json_config = json . load ( open ( f " { self . training_dir } /train.json " , ' r ' , encoding = " utf-8 " ) )
self . dataset_dir = f " { self . training_dir } /finetune/ "
self . dataset_path = f " { self . training_dir } /train.txt "
with open ( self . dataset_path , ' r ' , encoding = " utf-8 " ) as f :
self . dataset_size = len ( f . readlines ( ) )
self . killed = False
self . batch_size = self . json_config [ " batch_size " ]
self . save_rate = self . json_config [ " save_rate " ]
self . epoch = 0
self . epochs = self . json_config [ " epochs " ]
self . it = 0
self . its = calc_iterations ( self . epochs , self . dataset_size , self . batch_size )
self . step = 0
self . epoch = 0
self . steps = int ( self . its / self . dataset_size )
self . checkpoint = 0
self . checkpoints = int ( ( self . its - self . it ) / self . save_rate )
if args . tts_backend == " tortoise " :
gpus = self . config [ " gpus " ]
self . dataset_dir = f " ./training/ { self . config [ ' name ' ] } /finetune/ "
self . batch_size = self . config [ ' datasets ' ] [ ' train ' ] [ ' batch_size ' ]
self . dataset_path = self . config [ ' datasets ' ] [ ' train ' ] [ ' path ' ]
with open ( self . dataset_path , ' r ' , encoding = " utf-8 " ) as f :
self . dataset_size = len ( f . readlines ( ) )
self . its = self . config [ ' train ' ] [ ' niter ' ]
self . steps = 1
self . epochs = int ( self . its * self . batch_size / self . dataset_size )
self . checkpoints = int ( self . its / self . config [ ' logger ' ] [ ' save_checkpoint_freq ' ] )
elif args . tts_backend == " vall-e " :
self . batch_size = self . config [ ' batch_size ' ]
self . dataset_dir = f " . { self . config [ ' data_root ' ] } /finetune/ "
self . dataset_path = f " { self . config [ ' data_root ' ] } /train.txt "
self . its = 1
self . steps = 1
self . epochs = 1
self . checkpoints = 1
with open ( self . dataset_path , ' r ' , encoding = " utf-8 " ) as f :
self . dataset_size = len ( f . readlines ( ) )
self . json_config = json . load ( open ( f " { self . config [ ' data_root ' ] } /train.json " , ' r ' , encoding = " utf-8 " ) )
gpus = self . json_config [ ' gpus ' ]
self . gpus = self . json_config [ ' gpus ' ]
self . buffer = [ ]
@ -706,12 +690,15 @@ class TrainingState():
' loss ' : " " ,
}
self . buffer_json = None
self . json_buffer = [ ]
self . loss_milestones = [ 1.0 , 0.15 , 0.05 ]
if keep_x_past_checkpoints > 0 :
self . cleanup_old ( keep = keep_x_past_checkpoints )
if start :
self . spawn_process ( config_path = config_path , gpus = gpus )
self . spawn_process ( config_path = config_path , gpus = self . gpus )
def spawn_process ( self , config_path , gpus = 1 ) :
if args . tts_backend == " vall-e " :
@ -771,14 +758,19 @@ class TrainingState():
if ' lr ' in self . info :
self . statistics [ ' lr ' ] . append ( { ' epoch ' : epoch , ' it ' : self . it , ' value ' : self . info [ ' lr ' ] , ' type ' : ' learning_rate ' } )
for k in [ ' loss_text_ce ' , ' loss_mel_ce ' , ' loss_gpt_total ' ] :
if k not in self . info :
continue
if args . tts_backend == " tortoise " :
for k in [ ' loss_text_ce ' , ' loss_mel_ce ' , ' loss_gpt_total ' ] :
if k not in self . info :
continue
if k == " loss_gpt_total " :
self . losses . append ( self . statistics [ ' loss ' ] [ - 1 ] )
else :
self . statistics [ ' loss ' ] . append ( { ' epoch ' : epoch , ' it ' : self . it , ' value ' : self . info [ k ] , ' type ' : f ' { " val_ " if data [ " mode " ] == " validation " else " " } { k } ' } )
if k == " loss_gpt_total " :
self . losses . append ( self . statistics [ ' loss ' ] [ - 1 ] )
else :
self . statistics [ ' loss ' ] . append ( { ' epoch ' : epoch , ' it ' : self . it , ' value ' : self . info [ k ] , ' type ' : f ' { " val_ " if data [ " mode " ] == " validation " else " " } { k } ' } )
else :
k = " loss "
self . statistics [ ' loss ' ] . append ( { ' epoch ' : epoch , ' it ' : self . it , ' value ' : self . info [ k ] , ' type ' : f ' { " val_ " if data [ " mode " ] == " validation " else " " } { k } ' } )
self . losses . append ( self . statistics [ ' loss ' ] [ - 1 ] )
return data
@ -916,18 +908,62 @@ class TrainingState():
print ( " Removing " , path )
os . remove ( path )
def parse_valle_metrics ( self , data ) :
res = { }
res [ ' mode ' ] = " training "
res [ ' loss ' ] = data [ ' model.loss ' ]
res [ ' lr ' ] = data [ ' model.lr ' ]
res [ ' it ' ] = data [ ' global_step ' ]
res [ ' step ' ] = res [ ' it ' ] % self . dataset_size
res [ ' steps ' ] = self . steps
res [ ' epoch ' ] = int ( res [ ' it ' ] / self . dataset_size )
res [ ' iteration_rate ' ] = data [ ' elapsed_time ' ]
return res
def parse ( self , line , verbose = False , keep_x_past_checkpoints = 0 , buffer_size = 8 , progress = None ) :
self . buffer . append ( f ' { line } ' )
should_return = False
data = Non e
percent = 0
message = None
should_return = False
MESSAGE_START = ' Start training from epoch '
MESSAGE_FINSIHED = ' Finished training '
MESSAGE_SAVING = ' INFO: Saving models and training states. '
MESSAGE_METRICS_TRAINING = ' INFO: Training Metrics: '
MESSAGE_METRICS_VALIDATION = ' INFO: Validation Metrics: '
if args . tts_backend == " vall-e " :
if self . buffer_json :
self . json_buffer . append ( line )
if line . find ( " { " ) == 0 and not self . buffer_json :
self . buffer_json = True
self . json_buffer = [ line ]
if line . find ( " } " ) == 0 and self . buffer_json :
try :
data = json . loads ( " \n " . join ( self . json_buffer ) )
except Exception as e :
print ( str ( e ) )
if line . find ( ' Finished training ' ) > = 0 :
if data and ' model.loss ' in data :
self . training_started = True
data = self . parse_valle_metrics ( data )
print ( " Training JSON: " , data )
else :
data = None
self . buffer_json = None
self . json_buffer = [ ]
if line . find ( MESSAGE_FINSIHED ) > = 0 :
self . killed = True
# rip out iteration info
elif not self . training_started :
if line . find ( ' Start training from epoch ' ) > = 0 :
if line . find ( MESSAGE_START ) > = 0 :
self . training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
match = re . findall ( r ' epoch: ([ \ d,]+) ' , line )
@ -937,40 +973,39 @@ class TrainingState():
if match and len ( match ) > 0 :
self . it = int ( match [ 0 ] . replace ( " , " , " " ) )
self . checkpoints = int ( ( self . its - self . it ) / self . config[ ' logger ' ] [ ' save_checkpoint_freq ' ] )
self . checkpoints = int ( ( self . its - self . it ) / self . save_rate )
self . load_statistics ( )
should_return = True
else :
data = None
if line . find ( ' INFO: Saving models and training states. ' ) > = 0 :
if line . find ( MESSAGE_SAVING ) > = 0 :
self . checkpoint + = 1
message = f " [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... "
percent = self . checkpoint / self . checkpoints
self . cleanup_old ( keep = keep_x_past_checkpoints )
elif line . find ( ' INFO: Training Metrics: ' ) > = 0 :
data = json . loads ( line . split ( " INFO: Training Metrics: " ) [ - 1 ] )
elif line . find ( MESSAGE_METRICS_TRAINING ) > = 0 :
data = json . loads ( line . split ( MESSAGE_METRICS_TRAINING ) [ - 1 ] )
data [ ' mode ' ] = " training "
elif line . find ( ' INFO: Validation Metrics: ' ) > = 0 :
data = json . loads ( line . split ( " INFO: Validation Metrics: " ) [ - 1 ] )
elif line . find ( MESSAGE_METRICS_VALIDATION ) > = 0 :
data = json . loads ( line . split ( MESSAGE_METRICS_VALIDATION ) [ - 1 ] )
data [ ' mode ' ] = " validation "
if data is not None :
if ' : nan ' in line and not self . nan_detected :
self . nan_detected = self . it
self . parse_metrics ( data )
message = self . get_status ( )
if message :
percent = self . it / float ( self . its ) # self.epoch / float(self.epochs)
if progress is not None :
progress ( percent , message )
if data is not None :
if ' : nan ' in line and not self . nan_detected :
self . nan_detected = self . it
self . parse_metrics ( data )
message = self . get_status ( )
if message :
percent = self . it / float ( self . its ) # self.epoch / float(self.epochs)
if progress is not None :
progress ( percent , message )
self . buffer . append ( f ' [ { " {:.3f} " . format ( percent * 100 ) } %] { message } ' )
should_return = True
self . buffer . append ( f ' [ { " {:.3f} " . format ( percent * 100 ) } %] { message } ' )
should_return = True
if verbose and not self . training_started :
should_return = True
@ -1278,7 +1313,7 @@ def phonemize_txt_file( path ):
audio = split [ 0 ]
text = split [ 2 ]
phonemes = phonemizer ( text , preserve_punctuation = True , strip = True )
phonemes = phonemizer ( text )
reparsed . append ( f ' { audio } | { phonemes } ' )
f . write ( f ' \n { audio } | { phonemes } ' )
@ -1321,6 +1356,21 @@ def create_dataset_json( path ):
with open ( path . replace ( " .txt " , " .json " ) , ' w ' , encoding = ' utf-8 ' ) as f :
f . write ( json . dumps ( data , indent = " \t " ) )
def phonemizer(text, language="en-us"):
	"""Convert *text* to IPA phonemes using the ``phonemizer`` package.

	Accepts either a locale code or the literal name ``"english"`` (as
	reported by whisper) for *language*; the backend is taken from the
	module-level ``args.phonemizer_backend``.
	"""
	# Imported lazily so the module still loads when phonemizer is absent.
	from phonemizer import phonemize
	# Whisper reports full language names; the backend wants locale codes.
	lang = "en-us" if language == "english" else language
	return phonemize(
		text,
		language=lang,
		strip=True,
		preserve_punctuation=True,
		with_stress=True,
		backend=args.phonemizer_backend,
	)
def should_phonemize():
	"""Return True when the configured tokenizer expects IPA input.

	Requires both that the ``phonemizer`` package is importable and that the
	tokenizer JSON path ends in ``ipa.json``; prints the import failure and
	returns False otherwise.
	"""
	try:
		from phonemizer import phonemize
	except Exception as e:
		# best-effort availability probe: report why and fall back to plain text
		print(str(e))
		return False
	if args.tokenizer_json is None:
		return False
	return args.tokenizer_json[-8:] == "ipa.json"
def prepare_dataset ( voice , use_segments = False , text_length = 0 , audio_length = 0 , progress = gr . Progress ( ) ) :
indir = f ' ./training/ { voice } / '
infile = f ' { indir } /whisper.json '
@ -1332,7 +1382,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
errored = 0
messages = [ ]
normalize = True
phonemize = args. tokenizer_json is not None and args . tokenizer_json [ - 8 : ] == " ipa.json "
phonemize = should_phonemize( )
lines = { ' training ' : [ ] , ' validation ' : [ ] }
segments = { }
@ -1374,7 +1424,12 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
if use_segment and not use_segments :
exists = True
for segment in result [ ' segments ' ] :
if os . path . exists ( filename . replace ( " .wav " , f " _ { pad ( segment [ ' id ' ] , 4 ) } .wav " ) ) :
duration = segment [ ' end ' ] - segment [ ' start ' ]
if duration < = MIN_TRAINING_DURATION or MAX_TRAINING_DURATION < = duration :
continue
path = f ' { indir } /audio/ ' + filename . replace ( " .wav " , f " _ { pad ( segment [ ' id ' ] , 4 ) } .wav " )
if os . path . exists ( path ) :
continue
exists = False
break
@ -1396,6 +1451,10 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
}
else :
for segment in result [ ' segments ' ] :
duration = segment [ ' end ' ] - segment [ ' start ' ]
if duration < = MIN_TRAINING_DURATION or MAX_TRAINING_DURATION < = duration :
continue
segments [ filename . replace ( " .wav " , f " _ { pad ( segment [ ' id ' ] , 4 ) } .wav " ) ] = {
' text ' : segment [ ' text ' ] ,
' language ' : language ,
@ -1412,7 +1471,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
normalizer = result [ ' normalizer ' ]
phonemes = result [ ' phonemes ' ]
if phonemize and phonemes is None :
phonemes = phonemizer ( text , language = language if language != " english " else " en-us " , strip = True , preserve_punctuation = True , with_stress = True , backend = args . phonemizer_backend )
phonemes = phonemizer ( text , language = language )
if phonemize :
text = phonemes
@ -1456,7 +1515,10 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
print ( " Quantized: " , file )
tokens = tokenize_text ( text , stringed = False , skip_specials = True )
open ( f ' { indir } /valle/ { file . replace ( " .wav " , " .phn.txt " ) } ' , ' w ' , encoding = ' utf-8 ' ) . write ( " " . join ( tokens ) . replace ( " \u02C8 " , " \u02C8 " ) )
tokenized = " " . join ( tokens )
tokenized = tokenized . replace ( " \u02C8 " , " \u02C8 " )
tokenized = tokenized . replace ( " \u02CC " , " \u02CC " )
open ( f ' { indir } /valle/ { file . replace ( " .wav " , " .phn.txt " ) } ' , ' w ' , encoding = ' utf-8 ' ) . write ( tokenized )
training_joined = " \n " . join ( lines [ ' training ' ] )
validation_joined = " \n " . join ( lines [ ' validation ' ] )
@ -1471,8 +1533,7 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
return " \n " . join ( messages )
def calc_iterations(epochs, lines, batch_size):
	"""Compute the total number of training iterations for a run.

	Each epoch costs ``ceil(lines / batch_size)`` iterations — a short final
	batch still takes one full step — so the total is that count multiplied
	by the number of epochs, rounded up to a whole iteration.

	Args:
		epochs: number of passes over the dataset.
		lines: number of samples (dataset lines).
		batch_size: samples per batch.

	Returns:
		int: total iteration count.
	"""
	# The old truncating form `int(epochs * lines / float(batch_size))`
	# under-counted whenever lines was not a multiple of batch_size; the
	# unreachable duplicate return it left behind is removed here.
	return int(math.ceil(epochs * math.ceil(lines / batch_size)))
def schedule_learning_rate(iterations, schedule=LEARNING_RATE_SCHEDULE):
	"""Scale a fractional LR-decay schedule into absolute iteration milestones.

	Each entry of *schedule* is a fraction of the total *iterations*; the
	result is the corresponding list of whole-iteration milestone indices.
	"""
	milestones = []
	for fraction in schedule:
		milestones.append(int(iterations * fraction))
	return milestones
@ -1580,7 +1641,9 @@ def optimize_training_settings( **kwargs ):
if not os . path . exists ( get_halfp_model_path ( ) ) :
convert_to_halfp ( )
messages . append ( f " For { settings [ ' epochs ' ] } epochs with { lines } lines in batches of { settings [ ' batch_size ' ] } , iterating for { iterations } steps ( { int ( iterations / settings [ ' epochs ' ] ) } steps per epoch) " )
settings [ ' steps ' ] = int ( iterations / settings [ ' epochs ' ] )
messages . append ( f " For { settings [ ' epochs ' ] } epochs with { lines } lines in batches of { settings [ ' batch_size ' ] } , iterating for { iterations } steps ( { settings [ ' steps ' ] } ) steps per epoch) " )
return settings , messages
@ -1588,6 +1651,7 @@ def save_training_settings( **kwargs ):
messages = [ ]
settings = { }
settings . update ( kwargs )
outjson = f ' ./training/ { settings [ " voice " ] } /train.json '
with open ( outjson , ' w ' , encoding = " utf-8 " ) as f :
@ -1599,6 +1663,8 @@ def save_training_settings( **kwargs ):
with open ( settings [ ' dataset_path ' ] , ' r ' , encoding = " utf-8 " ) as f :
lines = len ( f . readlines ( ) )
settings [ ' iterations ' ] = calc_iterations ( epochs = settings [ ' epochs ' ] , lines = lines , batch_size = settings [ ' batch_size ' ] )
if not settings [ ' source_model ' ] or settings [ ' source_model ' ] == " auto " :
settings [ ' source_model ' ] = f " ./models/tortoise/autoregressive { ' _half ' if settings [ ' half_p ' ] else ' ' } .pth "
@ -1606,7 +1672,6 @@ def save_training_settings( **kwargs ):
if not os . path . exists ( get_halfp_model_path ( ) ) :
convert_to_halfp ( )
settings [ ' iterations ' ] = calc_iterations ( epochs = settings [ ' epochs ' ] , lines = lines , batch_size = settings [ ' batch_size ' ] )
messages . append ( f " For { settings [ ' epochs ' ] } epochs with { lines } lines, iterating for { settings [ ' iterations ' ] } steps " )
iterations_per_epoch = settings [ ' iterations ' ] / settings [ ' epochs ' ]
@ -1622,15 +1687,14 @@ def save_training_settings( **kwargs ):
if settings [ ' validation_rate ' ] < 1 :
settings [ ' validation_rate ' ] = 1
"""
settings [ ' validation_batch_size ' ] = int ( settings [ ' batch_size ' ] / settings [ ' gradient_accumulation_size ' ] )
settings [ ' iterations ' ] = calc_iterations ( epochs = settings [ ' epochs ' ] , lines = lines , batch_size = settings [ ' batch_size ' ] )
"""
if settings [ ' iterations ' ] % settings [ ' save_rate ' ] != 0 :
adjustment = int ( settings [ ' iterations ' ] / settings [ ' save_rate ' ] ) * settings [ ' save_rate ' ]
messages . append ( f " Iteration rate is not evenly divisible by save rate, adjusting: { settings [ ' iterations ' ] } => { adjustment } " )
settings [ ' iterations ' ] = adjustment
"""
settings [ ' validation_batch_size ' ] = int ( settings [ ' batch_size ' ] / settings [ ' gradient_accumulation_size ' ] )
if not os . path . exists ( settings [ ' validation_path ' ] ) :
settings [ ' validation_enabled ' ] = False
messages . append ( " Validation not found, disabling validation... " )
@ -1833,7 +1897,7 @@ def tokenize_text( text, stringed=True, skip_specials=False ):
tts . tokenizer
encoded = tokenizer . encode ( text )
decoded = tokenizer . tokenizer . decode ( encoded , skip_special_tokens = s pecials) . split ( " " )
decoded = tokenizer . tokenizer . decode ( encoded , skip_special_tokens = s kip_s pecials) . split ( " " )
if stringed :
return " \n " . join ( [ str ( encoded ) , str ( decoded ) ] )