From 8b4da29d5fda2e07b5439f62917208b34423c4a3 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 25 Feb 2023 13:55:25 +0000 Subject: [PATCH] csome adjustments to the training output parser, now updates per iteration for really large batches (like the one I'm doing for a dataset size of 19420) --- src/utils.py | 123 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 90 insertions(+), 33 deletions(-) diff --git a/src/utils.py b/src/utils.py index d8150ed..02798d9 100755 --- a/src/utils.py +++ b/src/utils.py @@ -470,14 +470,21 @@ class TrainingState(): self.epoch_rate = "" self.epoch_time_start = 0 self.epoch_time_end = 0 + + self.it_rate = "" + self.it_time_start = 0 + self.it_time_end = 0 + self.last_step = 0 + self.eta = "?" self.eta_hhmmss = "?" print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) - def parse(self, line, verbose=False, buffer_size=8, progress=None): - self.buffer.append(f'{line}') + def parse(self, line, verbose=False, buffer_size=8, progress=None, owner=True): + if owner: + self.buffer.append(f'{line}') # rip out iteration info if not self.training_started: @@ -492,47 +499,97 @@ class TrainingState(): if match and len(match) > 0: self.it = int(match[0].replace(",", "")) else: - if line.find('%|') > 0 and not self.open_state: - self.open_state = True - elif line.find('100%|') == 0 and self.open_state: - self.open_state = False - self.epoch = self.epoch + 1 + lapsed = line.find('100%|') == 0 and self.open_state - self.epoch_time_end = time.time() - self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start - self.epoch_time_start = time.time() - self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here - self.eta = (self.epochs - self.epoch) * self.epoch_time_delta - self.eta_hhmmss = str(timedelta(seconds=int(self.eta))) + if line.find('%|') > 0: + match = re.findall(r' +?(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) + if match and len(match) > 0: + match = match[0] + percent = int(match[0])/100.0 + progressbar = match[1] + step = int(match[2]) + steps = int(match[3]) + elapsed = match[4] + until = match[5] + rate = match[6] + + epoch_percent = self.epoch / float(self.epochs) + + if owner: + last_step = self.last_step + self.last_step = step + if last_step < step: + self.it = self.it + (step - last_step) + + if last_step > step and step == 0: + lapsed = True + + self.it_time_end = time.time() + self.it_time_delta = self.it_time_end-self.it_time_start + self.it_time_start = time.time() + self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]' + + self.eta = (self.its - self.it) * self.it_time_delta + self.eta_hhmmss = str(timedelta(seconds=int(self.eta))) + + message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}' + if progress is not None: + progress(epoch_percent, message) + if owner: + # print(f'{"{:.3f}".format(percent*100)}% {message}') + self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}') + + if line.find('%|') > 0 and not self.open_state: + if owner: + self.open_state = True + elif lapsed: + if owner: + self.open_state = False + self.epoch = self.epoch + 1 + self.it = int(self.epoch * (self.dataset_size / self.batch_size)) + + self.epoch_time_end = time.time() + self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start + self.epoch_time_start = time.time() + self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here + self.eta = (self.epochs - self.epoch) * self.epoch_time_delta + self.eta_hhmmss = str(timedelta(seconds=int(self.eta))) percent = self.epoch / float(self.epochs) - message = f'[{self.epoch}/{self.epochs}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} {self.status}' - print(f'{"{:.3f}".format(percent*100)}% {message}') + message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}' + if progress is not None: progress(percent, message) - self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') + + if owner: + print(f'{"{:.3f}".format(percent*100)}% {message}') + self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') if line.find('INFO: [epoch:') >= 0: - # easily rip out our stats... - match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) - if match and len(match) > 0: - for k, v in match: - self.info[k] = float(v) - - if 'loss_gpt_total' in self.info: - self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" - print(self.status) - self.buffer.append(self.status) + if owner: + # easily rip out our stats... + match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line) + if match and len(match) > 0: + for k, v in match: + self.info[k] = float(v) + + if 'loss_gpt_total' in self.info: + self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}" + print(self.status) + self.buffer.append(self.status) elif line.find('Saving models and training states') >= 0: - self.checkpoint = self.checkpoint + 1 + if owner: + self.checkpoint = self.checkpoint + 1 percent = self.checkpoint / float(self.checkpoints) message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...' - print(f'{"{:.3f}".format(percent*100)}% {message}') if progress is not None: progress(percent, message) - self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') + if owner: + print(f'{"{:.3f}".format(percent*100)}% {message}') + self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') - self.buffer = self.buffer[-buffer_size:] + if owner: + self.buffer = self.buffer[-buffer_size:] if verbose or not self.training_started: return "".join(self.buffer) @@ -552,7 +609,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress for line in iter(training_state.process.stdout.readline, ""): - res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) + res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True ) print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") if res: yield res @@ -565,13 +622,13 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress #if return_code: # raise subprocess.CalledProcessError(return_code, cmd) -def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): +def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)): global training_state if not training_state or not training_state.process: return "Training not in progress" for line in iter(training_state.process.stdout.readline, ""): - res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress ) + res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True ) if res: yield res