@ -550,8 +550,11 @@ class TrainingState():
self . eta_hhmmss = " ? "
self . eta_hhmmss = " ? "
self . last_info_check_at = 0
self . last_info_check_at = 0
self . statistics = [ ]
self . losses = [ ]
self . losses = [ ]
self . loss_milestones = [ 1.0 , 0.15 , 0.05 ]
self . load_losses ( )
self . load_losses ( )
if keep_x_past_datasets > 0 :
if keep_x_past_datasets > 0 :
self . cleanup_old ( keep = keep_x_past_datasets )
self . cleanup_old ( keep = keep_x_past_datasets )
@ -578,7 +581,7 @@ class TrainingState():
highest_step = self . last_info_check_at
highest_step = self . last_info_check_at
if not update :
if not update :
self . losse s = [ ]
self . statistic s = [ ]
if use_tensorboard :
if use_tensorboard :
logs = sorted ( [ f ' { self . dataset_dir } /tb_logger/ { d } ' for d in os . listdir ( f ' { self . dataset_dir } /tb_logger/ ' ) if d [ : 6 ] == " events " ] )
logs = sorted ( [ f ' { self . dataset_dir } /tb_logger/ { d } ' for d in os . listdir ( f ' { self . dataset_dir } /tb_logger/ ' ) if d [ : 6 ] == " events " ] )
@ -596,9 +599,12 @@ class TrainingState():
if update and s . step < = self . last_info_check_at :
if update and s . step < = self . last_info_check_at :
continue
continue
highest_step = max ( highest_step , s . step )
highest_step = max ( highest_step , s . step )
self . statistics . append ( { " step " : s . step , " value " : s . value , " type " : key } )
if key == ' loss_gpt_total ' :
self . losses . append ( { " step " : s . step , " value " : s . value , " type " : key } )
self . losses . append ( { " step " : s . step , " value " : s . value , " type " : key } )
except Exception as e :
except Exception as e :
print ( " Failed to parse event log: " , log )
pass
pass
else :
else :
@ -630,6 +636,9 @@ class TrainingState():
if update and int ( k ) < = self . last_info_check_at :
if update and int ( k ) < = self . last_info_check_at :
continue
continue
highest_step = max ( highest_step , s . step )
highest_step = max ( highest_step , s . step )
self . statistics . append ( { " step " : int ( k ) , " value " : infos [ k ] [ key ] , " type " : key } )
if key == " loss_gpt_total " :
self . losses . append ( { " step " : int ( k ) , " value " : infos [ k ] [ key ] , " type " : key } )
self . losses . append ( { " step " : int ( k ) , " value " : infos [ k ] [ key ] , " type " : key } )
self . last_info_check_at = highest_step
self . last_info_check_at = highest_step
@ -739,20 +748,36 @@ class TrainingState():
metric_loss . append ( f ' Loss: { " {:3f} " . format ( self . losses [ - 1 ] [ " value " ] ) } ' )
metric_loss . append ( f ' Loss: { " {:3f} " . format ( self . losses [ - 1 ] [ " value " ] ) } ' )
if len ( self . losses ) > = 2 :
if len ( self . losses ) > = 2 :
delta_loss = self . losses [ - 2 ] [ " value " ] - self . losses [ - 1 ] [ " value " ]
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
delta_step = self . losses [ - 2 ] [ " step " ] - self . losses [ - 1 ] [ " step " ]
d1_loss = self . losses [ - 1 ] [ " value " ]
d2_loss = self . losses [ - 2 ] [ " value " ]
inst_deriv = delta_loss / delta_step
dloss = d2_loss - d1_loss
est_loss = delta_loss + ( self . its - self . it ) * inst_deriv
metric_loss . append ( f ' Est. Final Loss: { " {:3f} " . format ( est_loss ) } ' )
d1_step = self . losses [ - 1 ] [ " step " ]
d2_step = self . losses [ - 2 ] [ " step " ]
print ( delta_loss , delta_step , inst_deriv , est_loss )
dstep = d2_step - d1_step
# don't bother if the loss went up
if dloss < 0 :
its_remain = self . its - self . it
inst_deriv = dloss / dstep
next_milestone = None
for milestone in self . loss_milestones :
if d1_loss > milestone :
next_milestone = milestone
break
if next_milestone :
# tfw can do simple calculus but not basic algebra in my head
est_its = ( next_milestone - d1_loss ) * ( dstep / dloss )
metric_loss . append ( f ' Est. milestone { next_milestone } in: { int ( est_its ) } its ' )
else :
est_loss = inst_deriv * its_remain + d1_loss
metric_loss . append ( f ' Est. final loss: { " {:3f} " . format ( est_loss ) } ' )
metric_loss = " , " . join ( metric_loss )
metric_loss = " , " . join ( metric_loss )
message = f ' [ { metric_step } ] [ { metric_rate } ] [ETA: { eta_hhmmss } ] [ { metric_loss } ] '
message = f ' [ { metric_step } ] [ { metric_rate } ] [ETA: { eta_hhmmss } ] [ { metric_loss } ] '
if lapsed :
if lapsed :
@ -859,9 +884,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
def get_training_losses():
    """Return the recorded training statistics as a pandas DataFrame.

    Reads the module-level ``training_state`` and wraps its ``statistics``
    list (entries are dicts with ``step``, ``value`` and ``type`` keys, as
    appended elsewhere in ``TrainingState``) in a ``pd.DataFrame``.

    Returns:
        A ``pd.DataFrame`` with one row per recorded statistic, or ``None``
        when there is no training state or nothing has been recorded yet.
    """
    global training_state

    # Nothing to plot: either training never started or no metrics were parsed.
    if not training_state or not training_state.statistics:
        return None

    return pd.DataFrame(training_state.statistics)
def update_training_dataplot ( config_path = None ) :
def update_training_dataplot ( config_path = None ) :
global training_state
global training_state
@ -870,13 +895,13 @@ def update_training_dataplot(config_path=None):
if not training_state :
if not training_state :
if config_path :
if config_path :
training_state = TrainingState ( config_path = config_path , start = False )
training_state = TrainingState ( config_path = config_path , start = False )
if training_state . losse s:
if training_state . statistic s:
update = gr . LinePlot . update ( value = pd . DataFrame ( training_state . losse s) )
update = gr . LinePlot . update ( value = pd . DataFrame ( training_state . statistic s) )
del training_state
del training_state
training_state = None
training_state = None
elif training_state . losse s:
elif training_state . statistic s:
training_state . load_losses ( )
training_state . load_losses ( )
update = gr . LinePlot . update ( value = pd . DataFrame ( training_state . losse s) )
update = gr . LinePlot . update ( value = pd . DataFrame ( training_state . statistic s) )
return update
return update