remove redundant stats, add showing LR

This commit is contained in:
mrq 2023-03-05 18:53:12 +00:00
parent 0231550287
commit 8b9c9e1bbf

View File

@ -758,7 +758,11 @@ class TrainingState():
except Exception as e: except Exception as e:
pass pass
self.metrics['step'] = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"] self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
if self.epochs != self.its:
self.metric.append(f"{self.it}/{self.its}")
if steps > 1:
self.metric.append(f"{step}/{steps}")
self.metrics['step'] = ", ".join(self.metrics['step']) self.metrics['step'] = ", ".join(self.metrics['step'])
if lapsed: if lapsed:
@ -786,7 +790,7 @@ class TrainingState():
self.metrics['rate'] = [] self.metrics['rate'] = []
if self.epoch_rate: if self.epoch_rate:
self.metrics['rate'].append(self.epoch_rate) self.metrics['rate'].append(self.epoch_rate)
if self.it_rate: if self.it_rate and self.epoch_rate != self.it_rate:
self.metrics['rate'].append(self.it_rate) self.metrics['rate'].append(self.it_rate)
self.metrics['rate'] = ", ".join(self.metrics['rate']) self.metrics['rate'] = ", ".join(self.metrics['rate'])
@ -802,6 +806,10 @@ class TrainingState():
pass pass
self.metrics['loss'] = [] self.metrics['loss'] = []
if 'learning_rate_gpt_0' in self.info:
self.metrics['loss'].append(f'LR: {"{:.9f}".format(self.info["learning_rate_gpt_0"])}')
if len(self.losses) > 0: if len(self.losses) > 0:
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}') self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
@ -845,7 +853,7 @@ class TrainingState():
self.metrics['loss'] = ", ".join(self.metrics['loss']) self.metrics['loss'] = ", ".join(self.metrics['loss'])
message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]" message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
if message: if message:
percent = self.it / float(self.its) # self.epoch / float(self.epochs) percent = self.it / float(self.its) # self.epoch / float(self.epochs)
@ -949,6 +957,7 @@ def stop_training():
try: try:
children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']] children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
except Exception as e: except Exception as e:
print(e)
pass pass
training_state.process.stdout.close() training_state.process.stdout.close()