@ -506,6 +506,8 @@ class TrainingState():
with open ( config_path , ' r ' ) as file :
with open ( config_path , ' r ' ) as file :
self . config = yaml . safe_load ( file )
self . config = yaml . safe_load ( file )
self . killed = False
self . dataset_dir = f " ./training/ { self . config [ ' name ' ] } / "
self . dataset_dir = f " ./training/ { self . config [ ' name ' ] } / "
self . batch_size = self . config [ ' datasets ' ] [ ' train ' ] [ ' batch_size ' ]
self . batch_size = self . config [ ' datasets ' ] [ ' train ' ] [ ' batch_size ' ]
self . dataset_path = self . config [ ' datasets ' ] [ ' train ' ] [ ' path ' ]
self . dataset_path = self . config [ ' datasets ' ] [ ' train ' ] [ ' path ' ]
@ -527,7 +529,6 @@ class TrainingState():
self . training_started = False
self . training_started = False
self . info = { }
self . info = { }
self . status = " ... "
self . epoch_rate = " "
self . epoch_rate = " "
self . epoch_time_start = 0
self . epoch_time_start = 0
@ -651,10 +652,12 @@ class TrainingState():
print ( " Removing " , path )
print ( " Removing " , path )
os . remove ( path )
os . remove ( path )
def parse ( self , line , verbose = False , buffer_size= 8 , keep_x_past_datasets= 0 , progress = None ) :
def parse ( self , line , verbose = False , keep_x_past_datasets= 0 , buffer_size = 8 , progress = None ) :
self . buffer . append ( f ' { line } ' )
self . buffer . append ( f ' { line } ' )
should_return = False
should_return = False
percent = 0
message = None
# rip out iteration info
# rip out iteration info
if not self . training_started :
if not self . training_started :
@ -679,7 +682,7 @@ class TrainingState():
match = re . findall ( r ' ( \ d+) % \ |(.+?) \ | ( \ d+| \ ?) \ /( \ d+| \ ?) \ [(.+?)<(.+?), +(.+?) \ ] ' , line )
match = re . findall ( r ' ( \ d+) % \ |(.+?) \ | ( \ d+| \ ?) \ /( \ d+| \ ?) \ [(.+?)<(.+?), +(.+?) \ ] ' , line )
if match and len ( match ) > 0 :
if match and len ( match ) > 0 :
match = match [ 0 ]
match = match [ 0 ]
per cent = int ( match [ 0 ] ) / 100.0
per _ cent = int ( match [ 0 ] ) / 100.0
progressbar = match [ 1 ]
progressbar = match [ 1 ]
step = int ( match [ 2 ] )
step = int ( match [ 2 ] )
steps = int ( match [ 3 ] )
steps = int ( match [ 3 ] )
@ -698,15 +701,40 @@ class TrainingState():
self . it_time_end = time . time ( )
self . it_time_end = time . time ( )
self . it_time_delta = self . it_time_end - self . it_time_start
self . it_time_delta = self . it_time_end - self . it_time_start
self . it_time_start = time . time ( )
self . it_time_start = time . time ( )
self . it_taken = self . it_taken + 1
try :
try :
rate = f ' { " {:.3f} " . format ( self . it_time_delta ) } s/it ' if self . it_time_delta > = 1 else f ' { " {:.3f} " . format ( 1 / self . it_time_delta ) } it/s '
rate = f ' { " {:.3f} " . format ( self . it_time_delta ) } s/it ' if self . it_time_delta > = 1 else f ' { " {:.3f} " . format ( 1 / self . it_time_delta ) } it/s '
self . it_rate = rate
self . it_rate = rate
except Exception as e :
except Exception as e :
pass
pass
last_loss = " "
metric_step = [ f " { self . epoch } / { self . epochs } " , f " { self . it } / { self . its } " , f " { step } / { steps } " ]
metric_step = " , " . join ( metric_step )
metric_rate = [ ]
if self . epoch_rate :
metric_rate . append ( self . epoch_rate )
if self . it_rate :
metric_rate . append ( self . it_rate )
metric_rate = " , " . join ( metric_rate )
eta_hhmmss = " ? "
if self . eta_hhmmss :
eta_hhmmss = self . eta_hhmmss
else :
try :
eta = ( self . its - self . it ) * ( self . it_time_deltas / self . it_taken )
eta = str ( timedelta ( seconds = int ( eta ) ) )
eta_hhmmss = eta
except Exception as e :
pass
metric_loss = [ ]
if len ( self . losses ) > 0 :
if len ( self . losses ) > 0 :
last_loss = f ' [Loss @ it. { self . losses [ - 1 ] [ " step " ] } : { self . losses [ - 1 ] [ " value " ] } ] '
metric_loss . append ( f ' Loss: { " {:3f} " . format ( self . losses [ - 1 ] [ " value " ] ) } ' )
message = f ' [ { self . epoch } / { self . epochs } , { self . it } / { self . its } , { step } / { steps } ] [ { self . epoch_rate } , { self . it_rate } ] { last_loss } [ETA: { self . eta_hhmmss } ] '
metric_loss = " , " . join ( metric_loss )
message = f ' [ { metric_step } ] [ { metric_rate } ] [ { metric_loss } ] [ETA: { eta_hhmmss } ] '
if lapsed :
if lapsed :
self . epoch = self . epoch + 1
self . epoch = self . epoch + 1
@ -741,16 +769,8 @@ class TrainingState():
for k , v in match :
for k , v in match :
self . info [ k ] = float ( v . replace ( " , " , " " ) )
self . info [ k ] = float ( v . replace ( " , " , " " ) )
if ' loss_gpt_total ' in self . info :
self . status = f " Total loss at epoch { self . epoch } : { self . info [ ' loss_gpt_total ' ] } "
"""
self . losses . append ( { " step " : self . it , " value " : self . info [ ' loss_text_ce ' ] , " type " : " loss_text_ce " } )
self . losses . append ( { " step " : self . it , " value " : self . info [ ' loss_mel_ce ' ] , " type " : " loss_mel_ce " } )
self . losses . append ( { " step " : self . it , " value " : self . info [ ' loss_gpt_total ' ] , " type " : " loss_gpt_total " } )
"""
should_return = True
self . load_losses ( update = True )
self . load_losses ( update = True )
should_return = True
elif line . find ( ' Saving models and training states ' ) > = 0 :
elif line . find ( ' Saving models and training states ' ) > = 0 :
self . checkpoint = self . checkpoint + 1
self . checkpoint = self . checkpoint + 1
@ -769,10 +789,18 @@ class TrainingState():
should_return = True
should_return = True
self . buffer = self . buffer [ - buffer_size : ]
self . buffer = self . buffer [ - buffer_size : ]
result = None
if should_return :
if should_return :
return " " . join ( self . buffer )
result = " " . join ( self . buffer ) if not self . training_started else message
def run_training ( config_path , verbose = False , gpus = 1 , buffer_size = 8 , keep_x_past_datasets = 0 , progress = gr . Progress ( track_tqdm = True ) ) :
return (
result ,
percent ,
message ,
)
def run_training ( config_path , verbose = False , gpus = 1 , keep_x_past_datasets = 0 , progress = gr . Progress ( track_tqdm = True ) ) :
global training_state
global training_state
if training_state and training_state . process :
if training_state and training_state . process :
return " Training already in progress "
return " Training already in progress "
@ -787,11 +815,10 @@ def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_
training_state = TrainingState ( config_path = config_path , keep_x_past_datasets = keep_x_past_datasets , gpus = gpus )
training_state = TrainingState ( config_path = config_path , keep_x_past_datasets = keep_x_past_datasets , gpus = gpus )
for line in iter ( training_state . process . stdout . readline , " " ) :
for line in iter ( training_state . process . stdout . readline , " " ) :
result , percent , message = training_state . parse ( line = line , verbose = verbose , keep_x_past_datasets = keep_x_past_datasets , progress = progress )
res = training_state . parse ( line = line , verbose = verbose , buffer_size = buffer_size , keep_x_past_datasets = keep_x_past_datasets , progress = progress )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
if res :
if res ult :
yield res
yield res ult
if training_state :
if training_state :
training_state . process . stdout . close ( )
training_state . process . stdout . close ( )
@ -824,15 +851,16 @@ def update_training_dataplot(config_path=None):
return update
return update
def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
	"""Re-attach to an already-running training process and stream its parsed output.

	Reads the live training subprocess's stdout line by line, feeds each line
	through ``training_state.parse`` and yields the parsed result (a generator,
	like ``run_training``), so a reconnected UI session can resume showing
	progress.

	Args:
		verbose: forwarded to ``TrainingState.parse``.
		progress: gradio progress tracker; forwarded to ``TrainingState.parse``.

	Returns:
		A status string if no training is in progress; otherwise yields the
		formatted output/progress messages produced by ``parse``.
	"""
	global training_state
	if not training_state or not training_state.process:
		return "Training not in progress"

	for line in iter(training_state.process.stdout.readline, ""):
		# FIX: the previous revision passed keep_x_past_datasets=keep_x_past_datasets
		# here, but no such name exists in this function's scope (it was only a
		# parameter of run_training), causing a NameError on the very first line of
		# output. Omit the kwarg and let parse() fall back to its own default (0).
		result, percent, message = training_state.parse(line=line, verbose=verbose, progress=progress)

		print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")

		if result:
			yield result
def stop_training ( ) :
def stop_training ( ) :
global training_state
global training_state
@ -845,6 +873,7 @@ def stop_training():
training_state . process . send_signal ( signal . SIGINT )
training_state . process . send_signal ( signal . SIGINT )
return_code = training_state . process . wait ( )
return_code = training_state . process . wait ( )
training_state = None
training_state = None
print ( " Killed training process. " )
return f " Training cancelled: { return_code } "
return f " Training cancelled: { return_code } "
def get_halfp_model_path ( ) :
def get_halfp_model_path ( ) :
@ -966,8 +995,18 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
if gradient_accumulation_size == 0 :
if gradient_accumulation_size == 0 :
gradient_accumulation_size = 1
gradient_accumulation_size = 1
if batch_size / gradient_accumulation_size < 2 :
gradient_accumulation_size = int ( batch_size / 2 )
if gradient_accumulation_size == 0 :
gradient_accumulation_size = 1
messages . append ( f " Gradient accumulation size is too large for a given batch size, clamping gradient accumulation size to: { gradient_accumulation_size } " )
elif batch_size % gradient_accumulation_size != 0 :
elif batch_size % gradient_accumulation_size != 0 :
gradient_accumulation_size = int ( batch_size / gradient_accumulation_size )
gradient_accumulation_size = int ( batch_size / gradient_accumulation_size )
if gradient_accumulation_size == 0 :
gradient_accumulation_size = 1
messages . append ( f " Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: { gradient_accumulation_size } " )
messages . append ( f " Batch size is not evenly divisible by the gradient accumulation size, adjusting gradient accumulation size to: { gradient_accumulation_size } " )
iterations = calc_iterations ( epochs = epochs , lines = lines , batch_size = batch_size )
iterations = calc_iterations ( epochs = epochs , lines = lines , batch_size = batch_size )