@ -470,14 +470,21 @@ class TrainingState():
self . epoch_rate = " "
self . epoch_time_start = 0
self . epoch_time_end = 0
self . it_rate = " "
self . it_time_start = 0
self . it_time_end = 0
self . last_step = 0
self . eta = " ? "
self . eta_hhmmss = " ? "
print ( " Spawning process: " , " " . join ( self . cmd ) )
self . process = subprocess . Popen ( self . cmd , stdout = subprocess . PIPE , stderr = subprocess . STDOUT , universal_newlines = True )
def parse ( self , line , verbose = False , buffer_size = 8 , progress = None ) :
self . buffer . append ( f ' { line } ' )
def parse ( self , line , verbose = False , buffer_size = 8 , progress = None , owner = True ) :
if owner :
self . buffer . append ( f ' { line } ' )
# rip out iteration info
if not self . training_started :
@ -492,47 +499,97 @@ class TrainingState():
if match and len ( match ) > 0 :
self . it = int ( match [ 0 ] . replace ( " , " , " " ) )
else :
if line . find ( ' % | ' ) > 0 and not self . open_state :
self . open_state = True
elif line . find ( ' 100 % | ' ) == 0 and self . open_state :
self . open_state = False
self . epoch = self . epoch + 1
lapsed = line . find ( ' 100 % | ' ) == 0 and self . open_state
self . epoch_time_end = time . time ( )
self . epoch_time_delta = self . epoch_time_end - self . epoch_time_start
self . epoch_time_start = time . time ( )
self . epoch_rate = f ' [ { " {:.3f} " . format ( self . epoch_time_delta ) } s/epoch] ' if self . epoch_time_delta > = 1 else f ' [ { " {:.3f} " . format ( 1 / self . epoch_time_delta ) } epoch/s] ' # I doubt anyone will have it/s rates, but its here
self . eta = ( self . epochs - self . epoch ) * self . epoch_time_delta
self . eta_hhmmss = str ( timedelta ( seconds = int ( self . eta ) ) )
if line . find ( ' % | ' ) > 0 :
match = re . findall ( r ' +?( \ d+) % \ |(.+?) \ | ( \ d+| \ ?) \ /( \ d+| \ ?) \ [(.+?)<(.+?), +(.+?) \ ] ' , line )
if match and len ( match ) > 0 :
match = match [ 0 ]
percent = int ( match [ 0 ] ) / 100.0
progressbar = match [ 1 ]
step = int ( match [ 2 ] )
steps = int ( match [ 3 ] )
elapsed = match [ 4 ]
until = match [ 5 ]
rate = match [ 6 ]
epoch_percent = self . epoch / float ( self . epochs )
if owner :
last_step = self . last_step
self . last_step = step
if last_step < step :
self . it = self . it + ( step - last_step )
if last_step > step and step == 0 :
lapsed = True
self . it_time_end = time . time ( )
self . it_time_delta = self . it_time_end - self . it_time_start
self . it_time_start = time . time ( )
self . it_rate = f ' [ { " {:.3f} " . format ( self . it_time_delta ) } s/it] ' if self . it_time_delta > = 1 else f ' [ { " {:.3f} " . format ( 1 / self . it_time_delta ) } it/s] '
self . eta = ( self . its - self . it ) * self . it_time_delta
self . eta_hhmmss = str ( timedelta ( seconds = int ( self . eta ) ) )
message = f ' [ { self . epoch } / { self . epochs } ] [ { self . it } / { self . its } ] [ETA: { self . eta_hhmmss } ] { self . epoch_rate } / { self . it_rate } { self . status } '
if progress is not None :
progress ( epoch_percent , message )
if owner :
# print(f'{"{:.3f}".format(percent*100)}% {message}')
self . buffer . append ( f ' [ { " {:.3f} " . format ( epoch_percent * 100 ) } % / { " {:.3f} " . format ( percent * 100 ) } %] { message } ' )
if line . find ( ' % | ' ) > 0 and not self . open_state :
if owner :
self . open_state = True
elif lapsed :
if owner :
self . open_state = False
self . epoch = self . epoch + 1
self . it = int ( self . epoch * ( self . dataset_size / self . batch_size ) )
self . epoch_time_end = time . time ( )
self . epoch_time_delta = self . epoch_time_end - self . epoch_time_start
self . epoch_time_start = time . time ( )
self . epoch_rate = f ' [ { " {:.3f} " . format ( self . epoch_time_delta ) } s/epoch] ' if self . epoch_time_delta > = 1 else f ' [ { " {:.3f} " . format ( 1 / self . epoch_time_delta ) } epoch/s] ' # I doubt anyone will have it/s rates, but its here
self . eta = ( self . epochs - self . epoch ) * self . epoch_time_delta
self . eta_hhmmss = str ( timedelta ( seconds = int ( self . eta ) ) )
percent = self . epoch / float ( self . epochs )
message = f ' [ { self . epoch } / { self . epochs } ] [ETA: { self . eta_hhmmss } ] { self . epoch_rate } { self . status } '
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
message = f ' [ { self . epoch } / { self . epochs } ] [ { self . it } / { self . its } ] [ ETA: { self . eta_hhmmss } ] { self . epoch _rate} / { self . it _rate} { self . status } '
if progress is not None :
progress ( percent , message )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
if owner :
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
if line . find ( ' INFO: [epoch: ' ) > = 0 :
# easily rip out our stats...
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: ([0-9] \ .[0-9]+?e[+-] \ d+) \ b ' , line )
if match and len ( match ) > 0 :
for k , v in match :
self . info [ k ] = float ( v )
if ' loss_gpt_total ' in self . info :
self . status = f " Total loss at epoch { self . epoch } : { self . info [ ' loss_gpt_total ' ] } "
print ( self . status )
self . buffer . append ( self . status )
if owner :
# easily rip out our stats...
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: ([0-9] \ .[0-9]+?e[+-] \ d+) \ b ' , line )
if match and len ( match ) > 0 :
for k , v in match :
self . info [ k ] = float ( v )
if ' loss_gpt_total ' in self . info :
self . status = f " Total loss at epoch { self . epoch } : { self . info [ ' loss_gpt_total ' ] } "
print ( self . status )
self . buffer . append ( self . status )
elif line . find ( ' Saving models and training states ' ) > = 0 :
self . checkpoint = self . checkpoint + 1
if owner :
self . checkpoint = self . checkpoint + 1
percent = self . checkpoint / float ( self . checkpoints )
message = f ' [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... '
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
if progress is not None :
progress ( percent , message )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
if owner :
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . buffer = self . buffer [ - buffer_size : ]
if owner :
self . buffer = self . buffer [ - buffer_size : ]
if verbose or not self . training_started :
return " " . join ( self . buffer )
@ -552,7 +609,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
for line in iter ( training_state . process . stdout . readline , " " ) :
res = training_state . parse ( line = line , verbose = verbose , buffer_size = buffer_size , progress = progress )
res = training_state . parse ( line = line , verbose = verbose , buffer_size = buffer_size , progress = progress , owner = True )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
if res :
yield res
@ -565,13 +622,13 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
#if return_code:
# raise subprocess.CalledProcessError(return_code, cmd)
def reconnect_training ( config_path, verbose= False , buffer_size = 8 , progress = gr . Progress ( track_tqdm = True ) ) :
def reconnect_training ( verbose= False , buffer_size = 8 , progress = gr . Progress ( track_tqdm = True ) ) :
global training_state
if not training_state or not training_state . process :
return " Training not in progress "
for line in iter ( training_state . process . stdout . readline , " " ) :
res = training_state . parse ( line = line , verbose = verbose , buffer_size = buffer_size , progress = progress )
res = training_state . parse ( line = line , verbose = verbose , buffer_size = buffer_size , progress = progress , owner = True )
if res :
yield res