@ -445,9 +445,16 @@ class TrainingState():
with open ( config_path , ' r ' ) as file :
self . config = yaml . safe_load ( file )
self . dataset_path = self . config [ ' datasets ' ] [ ' train ' ] [ ' path ' ]
with open ( self . dataset_path , ' r ' , encoding = " utf-8 " ) as f :
self . dataset_size = len ( f . readlines ( ) )
self . it = 0
self . its = self . config [ ' train ' ] [ ' niter ' ]
self . epoch = 0
self . epochs = int ( self . its / self . dataset_size )
self . checkpoint = 0
self . checkpoints = int ( self . its / self . config [ ' logger ' ] [ ' save_checkpoint_freq ' ] )
@ -459,10 +466,11 @@ class TrainingState():
self . info = { }
self . status = " "
self . it _rate = " "
self . it _time_start = 0
self . it _time_end = 0
self . epoch _rate = " "
self . epoch _time_start = 0
self . epoch _time_end = 0
self . eta = " ? "
self . eta_hhmmss = " ? "
print ( " Spawning process: " , " " . join ( self . cmd ) )
self . process = subprocess . Popen ( self . cmd , stdout = subprocess . PIPE , stderr = subprocess . STDOUT , universal_newlines = True )
@ -473,27 +481,30 @@ class TrainingState():
# rip out iteration info
if not self . training_started :
if line . find ( ' Start training from epoch ' ) > = 0 :
self . it _time_start = time . time ( )
self . epoch _time_start = time . time ( )
self . training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
match = re . findall ( r ' epoch: ([ \ d,]+) ' , line )
if match and len ( match ) > 0 :
self . epoch = int ( match [ 0 ] . replace ( " , " , " " ) )
match = re . findall ( r ' iter: ([ \ d,]+) ' , line )
if match and len ( match ) > 0 :
self . it = int ( match [ 0 ] . replace ( " , " , " " ) )
elif progress is not None :
if line . find ( ' 0% | ' ) == 0 :
if line . find ( ' %| ' ) > 0 and not self . open_state :
self . open_state = True
elif line . find ( ' 100 % | ' ) == 0 and self . open_state :
self . open_state = False
self . it = self . it + 1
self . epoch = self . epoch + 1
self . it _time_end = time . time ( )
self . it_time_delta = self . it_time_end - self . it _time_start
self . it _time_start = time . time ( )
self . it _rate = f ' [ { " {:.3f} " . format ( self . it_time_delta) } s/it] ' if self . it _time_delta > = 1 else f ' [ { " {:.3f} " . format ( 1 / self . it_time_delta) } it /s]' # I doubt anyone will have it/s rates, but its here
self . eta = ( self . its - self . it ) * self . it _time_delta
self . epoch _time_end = time . time ( )
self . epoch_time_delta = self . epoch_time_end - self . epoch _time_start
self . epoch _time_start = time . time ( )
self . epoch _rate = f ' [ { " {:.3f} " . format ( self . epoch_time_delta) } s/epoch] ' if self . epoch _time_delta > = 1 else f ' [ { " {:.3f} " . format ( 1 / self . epoch_time_delta) } epoch /s]' # I doubt anyone will have it/s rates, but its here
self . eta = ( self . epochs - self . epoch ) * self . epoch _time_delta
self . eta_hhmmss = str ( timedelta ( seconds = int ( self . eta ) ) )
progress ( self . it / float ( self . its ) , f ' [ { self . it } / { self . it s} ] [ETA: { self . eta_hhmmss } ] { self . it _rate} Training... { self . status } ' )
progress ( self . epoch / float ( self . epochs ) , f ' [ { self . epoch } / { self . epoch s} ] [ETA: { self . eta_hhmmss } ] { self . epoch _rate} Training... { self . status } ' )
if line . find ( ' INFO: [epoch: ' ) > = 0 :
# easily rip out our stats...
@ -502,11 +513,8 @@ class TrainingState():
for k , v in match :
self . info [ k ] = float ( v )
# ...and returns our loss rate
# it would be nice for losses to be shown at every step
if ' loss_gpt_total ' in self . info :
# self.info['step'] returns the steps, not iterations, so we won't even bother ripping the reported step count, as iteration count won't get ripped from the regex
self . status = f " Total loss at iteration { self . it } : { self . info [ ' loss_gpt_total ' ] } "
self . status = f " Total loss at epoch { self . epoch } : { self . info [ ' loss_gpt_total ' ] } "
elif line . find ( ' Saving models and training states ' ) > = 0 :
self . checkpoint = self . checkpoint + 1
progress ( self . checkpoint / float ( self . checkpoints ) , f ' [ { self . checkpoint } / { self . checkpoints } ] Saving checkpoint... ' )