From aafeb9f96a3fe39cc3ff74a02c509cd78c892d70 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 25 Feb 2023 16:44:25 +0000 Subject: [PATCH] actually fixed the training output text parser --- src/utils.py | 146 +++++++++++++++++++++++++++++---------------------- 1 file changed, 82 insertions(+), 64 deletions(-) diff --git a/src/utils.py b/src/utils.py index 3fc3776..1dadb87 100755 --- a/src/utils.py +++ b/src/utils.py @@ -470,10 +470,14 @@ class TrainingState(): self.epoch_rate = "" self.epoch_time_start = 0 self.epoch_time_end = 0 + self.epoch_time_deltas = 0 + self.epoch_taken = 0 self.it_rate = "" self.it_time_start = 0 self.it_time_end = 0 + self.it_time_deltas = 0 + self.it_taken = 0 self.last_step = 0 self.eta = "?" @@ -482,9 +486,8 @@ class TrainingState(): print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) - def parse(self, line, verbose=False, buffer_size=8, progress=None, owner=True): - if owner: - self.buffer.append(f'{line}') + def parse(self, line, verbose=False, buffer_size=8, progress=None ): + self.buffer.append(f'{line}') # rip out iteration info if not self.training_started: @@ -499,10 +502,10 @@ class TrainingState(): if match and len(match) > 0: self.it = int(match[0].replace(",", "")) else: - lapsed = line.find('100%|') == 0 and self.open_state + lapsed = False if line.find('%|') > 0: - match = re.findall(r' +?(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) + match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) if match and len(match) > 0: match = match[0] percent = int(match[0])/100.0 @@ -513,47 +516,65 @@ class TrainingState(): until = match[5] rate = match[6] - epoch_percent = self.epoch / float(self.epochs) - - if owner: - last_step = self.last_step - self.last_step = step - if last_step < step: - self.it = self.it + (step - last_step) - - if last_step > step and step == 0: - lapsed = True - - self.it_time_end = time.time() - self.it_time_delta = self.it_time_end-self.it_time_start - self.it_time_start = time.time() - self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]' - - self.eta = (self.its - self.it) * self.it_time_delta - self.eta_hhmmss = str(timedelta(seconds=int(self.eta))) + epoch_percent = self.it / float(self.its) # self.epoch / float(self.epochs) + + last_step = self.last_step + self.last_step = step + if last_step < step: + self.it = self.it + (step - last_step) + + if last_step == step and step == steps: + lapsed = True + + self.it_time_end = time.time() + self.it_time_delta = self.it_time_end-self.it_time_start + self.it_time_start = time.time() + try: + rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]' + self.it_rate = rate + except Exception as e: + pass + + """ + # I wanted frequently updated ETA, but I can't wrap my noggin around getting it to work on an empty belly + # will fix later + + #self.eta = (self.its - self.it) * self.it_time_delta + self.it_time_deltas = self.it_time_deltas + self.it_time_delta + self.it_taken = self.it_taken + 1 + self.eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken) + try: + eta = str(timedelta(seconds=int(self.eta))) + self.eta_hhmmss = eta + except Exception as e: + pass + """ message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}' if progress is not None: progress(epoch_percent, message) - if owner: - # print(f'{"{:.3f}".format(percent*100)}% {message}') - self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}') - - if line.find('%|') > 0 and not self.open_state: - if owner: - self.open_state = True - elif lapsed and self.open_state: - if owner: - self.open_state = False - self.epoch = self.epoch + 1 - self.it = int(self.epoch * (self.dataset_size / self.batch_size)) - - self.epoch_time_end = time.time() - self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start - self.epoch_time_start = time.time() - self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here - self.eta = (self.epochs - self.epoch) * self.epoch_time_delta - self.eta_hhmmss = str(timedelta(seconds=int(self.eta))) + + # print(f'{"{:.3f}".format(percent*100)}% {message}') + self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}') + + if lapsed: + self.epoch = self.epoch + 1 + self.it = int(self.epoch * (self.dataset_size / self.batch_size)) + + self.epoch_time_end = time.time() + self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start + self.epoch_time_start = time.time() + self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here + + #self.eta = (self.epochs - self.epoch) * self.epoch_time_delta + self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta + self.epoch_taken = self.epoch_taken + 1 + self.eta = (self.epochs - self.epoch) * (self.epoch_time_deltas / self.epoch_taken) + try: + eta = str(timedelta(seconds=int(self.eta))) + self.eta_hhmmss = eta + except Exception as e: + pass percent = self.epoch / float(self.epochs) message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}' @@ -561,35 +582,32 @@ class TrainingState(): if progress is not None: progress(percent, message) - if owner: - print(f'{"{:.3f}".format(percent*100)}% {message}') - self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') + print(f'{"{:.3f}".format(percent*100)}% {message}') + self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') if line.find('INFO: [epoch:') >= 0: - if owner: - # easily rip out our stats... - match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) - if match and len(match) > 0: - for k, v in match: - self.info[k] = float(v) + # easily rip out our stats... + match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) + if match and len(match) > 0: + for k, v in match: + self.info[k] = float(v) - if 'loss_gpt_total' in self.info: - self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" - print(self.status) - self.buffer.append(self.status) + if 'loss_gpt_total' in self.info: + self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" + print(self.status) + self.buffer.append(self.status) elif line.find('Saving models and training states') >= 0: - if owner: - self.checkpoint = self.checkpoint + 1 + self.checkpoint = self.checkpoint + 1 + percent = self.checkpoint / float(self.checkpoints) message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...' if progress is not None: progress(percent, message) - if owner: - print(f'{"{:.3f}".format(percent*100)}% {message}') - self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') - if owner: - self.buffer = self.buffer[-buffer_size:] + print(f'{"{:.3f}".format(percent*100)}% {message}') + self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') + + self.buffer = self.buffer[-buffer_size:] if verbose or not self.training_started: return "".join(self.buffer) @@ -609,7 +627,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress for line in iter(training_state.process.stdout.readline, ""): - res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True ) + res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") if res: yield res @@ -628,7 +646,7 @@ def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_ return "Training not in progress" for line in iter(training_state.process.stdout.readline, ""): - res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True ) + res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) if res: yield res