fixed the brain worm discrepancy between epochs, iterations, and steps
This commit is contained in:
parent
1cbcf14cff
commit
487f2ebf32
42
src/utils.py
42
src/utils.py
|
@ -445,9 +445,16 @@ class TrainingState():
|
|||
with open(config_path, 'r') as file:
|
||||
self.config = yaml.safe_load(file)
|
||||
|
||||
self.dataset_path = self.config['datasets']['train']['path']
|
||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||
self.dataset_size = len(f.readlines())
|
||||
|
||||
self.it = 0
|
||||
self.its = self.config['train']['niter']
|
||||
|
||||
self.epoch = 0
|
||||
self.epochs = int(self.its/self.dataset_size)
|
||||
|
||||
self.checkpoint = 0
|
||||
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
||||
|
||||
|
@ -459,10 +466,11 @@ class TrainingState():
|
|||
self.info = {}
|
||||
self.status = ""
|
||||
|
||||
self.it_rate = ""
|
||||
self.it_time_start = 0
|
||||
self.it_time_end = 0
|
||||
self.epoch_rate = ""
|
||||
self.epoch_time_start = 0
|
||||
self.epoch_time_end = 0
|
||||
self.eta = "?"
|
||||
self.eta_hhmmss = "?"
|
||||
|
||||
print("Spawning process: ", " ".join(self.cmd))
|
||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
@ -473,27 +481,30 @@ class TrainingState():
|
|||
# rip out iteration info
|
||||
if not self.training_started:
|
||||
if line.find('Start training from epoch') >= 0:
|
||||
self.it_time_start = time.time()
|
||||
self.epoch_time_start = time.time()
|
||||
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||
|
||||
match = re.findall(r'epoch: ([\d,]+)', line)
|
||||
if match and len(match) > 0:
|
||||
self.epoch = int(match[0].replace(",", ""))
|
||||
match = re.findall(r'iter: ([\d,]+)', line)
|
||||
if match and len(match) > 0:
|
||||
self.it = int(match[0].replace(",", ""))
|
||||
elif progress is not None:
|
||||
if line.find(' 0%|') == 0:
|
||||
if line.find('%|') > 0 and not self.open_state:
|
||||
self.open_state = True
|
||||
elif line.find('100%|') == 0 and self.open_state:
|
||||
self.open_state = False
|
||||
self.it = self.it + 1
|
||||
self.epoch = self.epoch + 1
|
||||
|
||||
self.it_time_end = time.time()
|
||||
self.it_time_delta = self.it_time_end-self.it_time_start
|
||||
self.it_time_start = time.time()
|
||||
self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]' # I doubt anyone will have it/s rates, but its here
|
||||
self.eta = (self.its - self.it) * self.it_time_delta
|
||||
self.epoch_time_end = time.time()
|
||||
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
|
||||
self.epoch_time_start = time.time()
|
||||
self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here
|
||||
self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
|
||||
self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
|
||||
|
||||
progress(self.it / float(self.its), f'[{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.it_rate} Training... {self.status}')
|
||||
progress(self.epoch / float(self.epochs), f'[{self.epoch}/{self.epochs}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} Training... {self.status}')
|
||||
|
||||
if line.find('INFO: [epoch:') >= 0:
|
||||
# easily rip out our stats...
|
||||
|
@ -501,12 +512,9 @@ class TrainingState():
|
|||
if match and len(match) > 0:
|
||||
for k, v in match:
|
||||
self.info[k] = float(v)
|
||||
|
||||
# ...and returns our loss rate
|
||||
# it would be nice for losses to be shown at every step
|
||||
|
||||
if 'loss_gpt_total' in self.info:
|
||||
# self.info['step'] returns the steps, not iterations, so we won't even bother ripping the reported step count, as iteration count won't get ripped from the regex
|
||||
self.status = f"Total loss at iteration {self.it}: {self.info['loss_gpt_total']}"
|
||||
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
|
||||
elif line.find('Saving models and training states') >= 0:
|
||||
self.checkpoint = self.checkpoint + 1
|
||||
progress(self.checkpoint / float(self.checkpoints), f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...')
|
||||
|
|
Loading…
Reference in New Issue
Block a user