From 8b9c9e1bbf4ef8fb03364942746954dc64fe8570 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 5 Mar 2023 18:53:12 +0000 Subject: [PATCH] remove redundant stats, add showing LR --- src/utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/utils.py b/src/utils.py index fd0851c..69ad56d 100755 --- a/src/utils.py +++ b/src/utils.py @@ -758,7 +758,11 @@ class TrainingState(): except Exception as e: pass - self.metrics['step'] = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"] + self.metrics['step'] = [f"{self.epoch}/{self.epochs}"] + if self.epochs != self.its: + self.metric.append(f"{self.it}/{self.its}") + if steps > 1: + self.metric.append(f"{step}/{steps}") self.metrics['step'] = ", ".join(self.metrics['step']) if lapsed: @@ -786,7 +790,7 @@ class TrainingState(): self.metrics['rate'] = [] if self.epoch_rate: self.metrics['rate'].append(self.epoch_rate) - if self.it_rate: + if self.it_rate and self.epoch_rate != self.it_rate: self.metrics['rate'].append(self.it_rate) self.metrics['rate'] = ", ".join(self.metrics['rate']) @@ -802,6 +806,10 @@ class TrainingState(): pass self.metrics['loss'] = [] + + if 'learning_rate_gpt_0' in self.info: + self.metrics['loss'].append(f'LR: {"{:.9f}".format(self.info["learning_rate_gpt_0"])}') + if len(self.losses) > 0: self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}') @@ -845,7 +853,7 @@ class TrainingState(): self.metrics['loss'] = ", ".join(self.metrics['loss']) - message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]" + message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]" if message: percent = self.it / float(self.its) # self.epoch / float(self.epochs) @@ -949,6 +957,7 @@ def stop_training(): try: children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']] except Exception as e: + print(e) pass training_state.process.stdout.close()