@ -552,6 +552,11 @@ class TrainingState():
self . last_info_check_at = 0
self . last_info_check_at = 0
self . statistics = [ ]
self . statistics = [ ]
self . losses = [ ]
self . losses = [ ]
self . metrics = {
' step ' : " " ,
' rate ' : " " ,
' loss ' : " " ,
}
self . loss_milestones = [ 1.0 , 0.15 , 0.05 ]
self . loss_milestones = [ 1.0 , 0.15 , 0.05 ]
@ -691,7 +696,37 @@ class TrainingState():
lapsed = False
lapsed = False
message = None
message = None
if line . find ( ' % | ' ) > 0 :
if line . find ( ' INFO: [epoch: ' ) > = 0 :
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ' : nan ' in line :
should_return = True
print ( " ! NAN DETECTED ! " )
self . buffer . append ( " ! NAN DETECTED ! " )
# easily rip out our stats...
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: +?([0-9] \ .[0-9]+?e[+-] \ d+|[ \ d,]+) \ b ' , line )
if match and len ( match ) > 0 :
for k , v in match :
self . info [ k ] = float ( v . replace ( " , " , " " ) )
self . load_losses ( update = True )
should_return = True
elif line . find ( ' Saving models and training states ' ) > = 0 :
self . checkpoint = self . checkpoint + 1
percent = self . checkpoint / float ( self . checkpoints )
message = f ' [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... '
if progress is not None :
progress ( percent , message )
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . cleanup_old ( keep = keep_x_past_datasets )
elif line . find ( ' % | ' ) > 0 :
match = re . findall ( r ' ( \ d+) % \ |(.+?) \ | ( \ d+| \ ?) \ /( \ d+| \ ?) \ [(.+?)<(.+?), +(.+?) \ ] ' , line )
match = re . findall ( r ' ( \ d+) % \ |(.+?) \ | ( \ d+| \ ?) \ /( \ d+| \ ?) \ [(.+?)<(.+?), +(.+?) \ ] ' , line )
if match and len ( match ) > 0 :
if match and len ( match ) > 0 :
match = match [ 0 ]
match = match [ 0 ]
@ -722,15 +757,34 @@ class TrainingState():
except Exception as e :
except Exception as e :
pass
pass
metric_step = [ f " { self . epoch } / { self . epochs } " , f " { self . it } / { self . its } " , f " { step } / { steps } " ]
self . metrics [ ' step ' ] = [ f " { self . epoch } / { self . epochs } " , f " { self . it } / { self . its } " , f " { step } / { steps } " ]
metric_step = " , " . join ( metric_step )
self . metrics [ ' step ' ] = " , " . join ( self . metrics [ ' step ' ] )
if lapsed :
self . epoch = self . epoch + 1
self . it = int ( self . epoch * ( self . dataset_size / self . batch_size ) )
self . epoch_time_end = time . time ( )
self . epoch_time_delta = self . epoch_time_end - self . epoch_time_start
self . epoch_time_start = time . time ( )
self . epoch_rate = f ' { " {:.3f} " . format ( self . epoch_time_delta ) } s/epoch ' if self . epoch_time_delta > = 1 else f ' { " {:.3f} " . format ( 1 / self . epoch_time_delta ) } epoch/s ' # I doubt anyone will have it/s rates, but its here
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
self . epoch_time_deltas = self . epoch_time_deltas + self . epoch_time_delta
self . epoch_taken = self . epoch_taken + 1
self . eta = ( self . epochs - self . epoch ) * ( self . epoch_time_deltas / self . epoch_taken )
try :
eta = str ( timedelta ( seconds = int ( self . eta ) ) )
self . eta_hhmmss = eta
except Exception as e :
pass
metric_rate = [ ]
self . metrics [ ' rate ' ] = [ ]
if self . epoch_rate :
if self . epoch_rate :
metric_rate . append ( self . epoch_rate )
self . metrics [ ' rate ' ] . append ( self . epoch_rate )
if self . it_rate :
if self . it_rate :
metric_rate . append ( self . it_rate )
self . metrics [ ' rate ' ] . append ( self . it_rate )
metric_rate = " , " . join ( metric_rate )
self . metrics [ ' rate ' ] = " , " . join ( self . metrics [ ' rate ' ] )
eta_hhmmss = " ? "
eta_hhmmss = " ? "
if self . eta_hhmmss :
if self . eta_hhmmss :
@ -743,9 +797,9 @@ class TrainingState():
except Exception as e :
except Exception as e :
pass
pass
metric_loss = [ ]
self . metrics [ ' loss ' ] = [ ]
if len ( self . losses ) > 0 :
if len ( self . losses ) > 0 :
metric_loss . append ( f ' Loss: { " {:3f} " . format ( self . losses [ - 1 ] [ " value " ] ) } ' )
self . metrics [ ' loss ' ] . append ( f ' Loss: { " {:3f} " . format ( self . losses [ - 1 ] [ " value " ] ) } ' )
if len ( self . losses ) > = 2 :
if len ( self . losses ) > = 2 :
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
@ -771,33 +825,14 @@ class TrainingState():
if next_milestone :
if next_milestone :
# tfw can do simple calculus but not basic algebra in my head
# tfw can do simple calculus but not basic algebra in my head
est_its = ( next_milestone - d1_loss ) * ( dstep / dloss )
est_its = ( next_milestone - d1_loss ) * ( dstep / dloss )
metric_loss . append ( f ' Est. milestone { next_milestone } in: { int ( est_its ) } its ' )
self . metrics [ ' loss ' ] . append ( f ' Est. milestone { next_milestone } in: { int ( est_its ) } its ' )
else :
else :
est_loss = inst_deriv * its_remain + d1_loss
est_loss = inst_deriv * its_remain + d1_loss
metric_loss . append ( f ' Est. final loss: { " {:3f} " . format ( est_loss ) } ' )
self . metrics [ ' loss ' ] . append ( f ' Est. final loss: { " {:3f} " . format ( est_loss ) } ' )
metric_loss = " , " . join ( metric_loss )
self . metrics [ ' loss ' ] = " , " . join ( self . metrics [ ' loss ' ] )
message = f ' [ { metric_step } ] [ { metric_rate } ] [ETA: { eta_hhmmss } ] [ { metric_loss } ] '
message = f " [ { self . metrics [ ' step ' ] } ] [ { self . metrics [ ' rate ' ] } ] [ETA: { eta_hhmmss } ] [ { self . metrics [ ' loss ' ] } ] "
if lapsed :
self . epoch = self . epoch + 1
self . it = int ( self . epoch * ( self . dataset_size / self . batch_size ) )
self . epoch_time_end = time . time ( )
self . epoch_time_delta = self . epoch_time_end - self . epoch_time_start
self . epoch_time_start = time . time ( )
self . epoch_rate = f ' { " {:.3f} " . format ( self . epoch_time_delta ) } s/epoch ' if self . epoch_time_delta > = 1 else f ' { " {:.3f} " . format ( 1 / self . epoch_time_delta ) } epoch/s ' # I doubt anyone will have it/s rates, but its here
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
self . epoch_time_deltas = self . epoch_time_deltas + self . epoch_time_delta
self . epoch_taken = self . epoch_taken + 1
self . eta = ( self . epochs - self . epoch ) * ( self . epoch_time_deltas / self . epoch_taken )
try :
eta = str ( timedelta ( seconds = int ( self . eta ) ) )
self . eta_hhmmss = eta
except Exception as e :
pass
if message :
if message :
percent = self . it / float ( self . its ) # self.epoch / float(self.epochs)
percent = self . it / float ( self . its ) # self.epoch / float(self.epochs)
@ -806,36 +841,6 @@ class TrainingState():
self . buffer . append ( f ' [ { " {:.3f} " . format ( percent * 100 ) } %] { message } ' )
self . buffer . append ( f ' [ { " {:.3f} " . format ( percent * 100 ) } %] { message } ' )
if line . find ( ' INFO: [epoch: ' ) > = 0 :
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ' : nan ' in line :
should_return = True
print ( " ! NAN DETECTED ! " )
self . buffer . append ( " ! NAN DETECTED ! " )
# easily rip out our stats...
match = re . findall ( r ' \ b([a-z_0-9]+?) \ b: +?([0-9] \ .[0-9]+?e[+-] \ d+|[ \ d,]+) \ b ' , line )
if match and len ( match ) > 0 :
for k , v in match :
self . info [ k ] = float ( v . replace ( " , " , " " ) )
self . load_losses ( update = True )
should_return = True
elif line . find ( ' Saving models and training states ' ) > = 0 :
self . checkpoint = self . checkpoint + 1
percent = self . checkpoint / float ( self . checkpoints )
message = f ' [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... '
if progress is not None :
progress ( percent , message )
print ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . buffer . append ( f ' { " {:.3f} " . format ( percent * 100 ) } % { message } ' )
self . cleanup_old ( keep = keep_x_past_datasets )
if verbose and not self . training_started :
if verbose and not self . training_started :
should_return = True
should_return = True