@ -415,7 +415,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
its = config [ ' train ' ] [ ' niter ' ]
its = config [ ' train ' ] [ ' niter ' ]
checkpoint = 0
checkpoint = 0
checkpoints = config[ ' logger ' ] [ ' save_checkpoint_freq ' ] / its
checkpoints = its / config[ ' logger ' ] [ ' save_checkpoint_freq ' ]
buffer_size = 8
buffer_size = 8
open_state = False
open_state = False
@ -443,40 +443,35 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
elif progress is not None :
elif progress is not None :
if line . find ( ' 0 % | ' ) == 0 :
if line . find ( ' 0 % | ' ) == 0 :
open_state = True
open_state = True
it_time_start = time . time ( )
elif line . find ( ' 100 % | ' ) == 0 and open_state :
elif line . find ( ' 100 % | ' ) == 0 and open_state :
it_time_end = time . time ( )
open_state = False
open_state = False
it = it + 1
it = it + 1
it_time_end = time . time ( )
it_time_delta = it_time_end - it_time_start
it_time_delta = it_time_end - it_time_start
it_rate = f ' [ { " {:.3f} " . format ( it_time_delta ) } s/it] ' if it_time_delta > = 1 and it_time_delta != 0 else f ' [ { " {:.3f} " . format ( 1 / it_time_delta ) } it/s] ' # I doubt anyone will have it/s rates, but its here
it_time_start = time . time ( )
it_rate = f ' [ { " {:.3f} " . format ( it_time_delta ) } s/it] ' if it_time_delta > = 1 else f ' [ { " {:.3f} " . format ( 1 / it_time_delta ) } it/s] ' # I doubt anyone will have it/s rates, but its here
progress ( it / float ( its ) , f ' [ { it } / { its } ] { it_rate } Training... { status } ' )
progress ( it / float ( its ) , f ' [ { it } / { its } ] { it_rate } Training... { status } ' )
# try because I haven't tested this yet
if line . find ( ' INFO: [epoch: ' ) > = 0 :
try :
# easily rip out our stats...
if line . find ( ' INFO: [epoch: ' ) > = 0 :
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: ([0-9] \ .[0-9]+?e[+-] \ d+) \ b ' , line )
# easily rip out our stats...
if match and len ( match ) > 0 :
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: ([0-9] \ .[0-9]+?e[+-] \ d+) \ b ' , line )
for k , v in match :
if match and len ( match ) > 0 :
info [ k ] = float ( v )
for k , v in match :
info [ k ] = float ( v )
# ...and returns our loss rate
# it would be nice for losses to be shown at every step
# ...and returns our loss rate
if ' loss_gpt_total ' in info :
# it would be nice for losses to be shown at every step
status = f " Total loss at step { int ( info [ ' step ' ] ) } : { info [ ' loss_gpt_total ' ] } "
if ' loss_gpt_total ' in info :
elif line . find ( ' Saving models and training states ' ) > = 0 :
status = f " Total loss at step { int ( info [ ' step ' ] ) } : { info [ ' loss_gpt_total ' ] } "
except Exception as e :
pass
if line . find ( ' Saving models and training states ' ) > = 0 :
checkpoint = checkpoint + 1
checkpoint = checkpoint + 1
progress ( checkpoint / float ( checkpoints ) , f ' [ { checkpoint } / { checkpoints } ] Saving checkpoint... ' )
progress ( checkpoint / float ( checkpoints ) , f ' [ { checkpoint } / { checkpoints } ] Saving checkpoint... ' )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
print ( f " [Training] [ { datetime . now ( ) . isoformat ( ) } ] { line [ : - 1 ] } " )
if verbose :
if verbose or not training_started :
yield " " . join ( buffer [ - buffer_size : ] )
yield " " . join ( buffer [ - buffer_size : ] )
training_process . stdout . close ( )
training_process . stdout . close ( )