@ -438,7 +438,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
# not strictly necessary, but wrapping the training state in a class keeps things tidier
class TrainingState ( ) :
def __init__ ( self , config_path , buffer_size = 8 ):
def __init__ ( self , config_path ):
self . cmd = [ ' train.bat ' , config_path ] if os . name == " nt " else [ ' bash ' , ' ./train.sh ' , config_path ]
# parse config to get its iteration
@ -465,7 +465,7 @@ class TrainingState():
self . training_started = False
self . info = { }
self . status = " "
self . status = " ... "
self . epoch_rate = " "
self . epoch_time_start = 0
@ -491,7 +491,7 @@ class TrainingState():
match = re . findall ( r ' iter: ([ \ d,]+) ' , line )
if match and len ( match ) > 0 :
self . it = int ( match [ 0 ] . replace ( " , " , " " ) )
el if progre ss is not Non e:
el se:
if line . find ( ' % | ' ) > 0 and not self . open_state :
self . open_state = True
elif line . find ( ' 100 % | ' ) == 0 and self . open_state :
@ -505,7 +505,12 @@ class TrainingState():
self . eta = ( self . epochs - self . epoch ) * self . epoch_time_delta
self . eta_hhmmss = str ( timedelta ( seconds = int ( self . eta ) ) )
progress ( self . epoch / float ( self . epochs ) , f ' [ { self . epoch } / { self . epochs } ] [ETA: { self . eta_hhmmss } ] { self . epoch_rate } Training... { self . status } ' )
percent = self . epoch / float ( self . epochs )
message = f ' [ { self . epoch } / { self . epochs } ] [ETA: { self . eta_hhmmss } ] { self . epoch_rate } { self . status } '
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
if progress is not None :
progress ( percent , message )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
if line . find ( ' INFO: [epoch: ' ) > = 0 :
# easily rip out our stats...
@ -516,12 +521,20 @@ class TrainingState():
if ' loss_gpt_total ' in self . info :
self . status = f " Total loss at epoch { self . epoch } : { self . info [ ' loss_gpt_total ' ] } "
print ( self . status )
self . buffer . append ( self . status )
elif line . find ( ' Saving models and training states ' ) > = 0 :
self . checkpoint = self . checkpoint + 1
progress ( self . checkpoint / float ( self . checkpoints ) , f ' [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... ' )
percent = self . checkpoint / float ( self . checkpoints )
message = f ' [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... '
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
if progress is not None :
progress ( percent , message )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . buffer = self . buffer [ - buffer_size : ]
if verbose or not self . training_started :
return " " . join ( self . buffer [ - buffer_size : ] )
return " " . join ( self . buffer )
def run_training ( config_path , verbose = False , buffer_size = 8 , progress = gr . Progress ( track_tqdm = True ) ) :
global training_state
@ -535,25 +548,22 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
unload_whisper ( )
unload_voicefixer ( )
training_state = TrainingState ( config_path = config_path , buffer_size = buffer_size )
training_state = TrainingState ( config_path = config_path )
for line in iter ( training_state . process . stdout . readline , " " ) :
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
res = training_state . parse ( line = line , verbose = verbose , buffer_size = buffer_size , progress = progress )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
if res :
yield res
training_state . process . stdout . close ( )
return_code = training_state . process . wait ( )
output = " " . join ( training_state . buffer [ - buffer_size : ] )
training_state = None
#if return_code:
# raise subprocess.CalledProcessError(return_code, cmd)
return output
def reconnect_training ( config_path , verbose = False , buffer_size = 8 , progress = gr . Progress ( track_tqdm = True ) ) :
global training_state
if not training_state or not training_state . process :
@ -564,10 +574,6 @@ def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Pr
if res :
yield res
output = " " . join ( training_state . buffer [ - buffer_size : ] )
return output
def stop_training ( ) :
global training_process
if training_process is None :
@ -644,7 +650,7 @@ EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
def schedule_learning_rate(iterations):
    """Scale every entry of the module-level ``EPOCH_SCHEDULE`` by ``iterations``.

    Each schedule entry is multiplied by ``iterations`` and passed through
    ``int``; the results are returned as a new list in schedule order.

    NOTE(review): ``iterations`` is presumably an iterations-per-epoch count,
    making the result a list of step milestones -- confirm against the caller.
    """
    milestones = []
    for factor in EPOCH_SCHEDULE:
        milestones.append(int(iterations * factor))
    return milestones
def optimize_training_settings ( epochs , batch_size, learning_rate, learning_rate_schedul e, mega_batch_factor , print_rate , save_rate , resume_path , half_p , voice ) :
def optimize_training_settings ( epochs , learning_rate, learning_rate_schedul e, batch_siz e, mega_batch_factor , print_rate , save_rate , resume_path , half_p , voice ) :
name = f " { voice } -finetune "
dataset_name = f " { voice } -train "
dataset_path = f " ./training/ { voice } /train.txt "
@ -694,9 +700,9 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
messages . append ( f " For { epochs } epochs with { lines } lines in batches of { batch_size } , iterating for { iterations } steps ( { int ( iterations / epochs ) } steps per epoch) " )
return (
batch_size ,
learning_rate ,
learning_rate_schedule ,
batch_size ,
mega_batch_factor ,
print_rate ,
save_rate ,
@ -704,7 +710,7 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
messages
)
def save_training_settings ( iterations = None , batch_size= None , learning_rate= None , learning_rate_schedul e= None , mega_batch_factor = None , print_rate = None , save_rate = None , name = None , dataset_name = None , dataset_path = None , validation_name = None , validation_path = None , output_name = None , resume_path = None , half_p = None ) :
def save_training_settings ( iterations = None , learning_rate= None , learning_rate_schedul e= None , batch_siz e= None , mega_batch_factor = None , print_rate = None , save_rate = None , name = None , dataset_name = None , dataset_path = None , validation_name = None , validation_path = None , output_name = None , resume_path = None , half_p = None ) :
settings = {
" iterations " : iterations if iterations else 500 ,
" batch_size " : batch_size if batch_size else 64 ,