output fixes, I'm not sure why ETA wasn't working but it works in testing

This commit is contained in:
mrq 2023-03-12 15:17:07 +00:00
parent 098d7ad635
commit 296129ba9c

View File

@ -638,7 +638,6 @@ class TrainingState():
self.loss_milestones = [ 1.0, 0.15, 0.05 ] self.loss_milestones = [ 1.0, 0.15, 0.05 ]
self.load_statistics()
if keep_x_past_checkpoints > 0: if keep_x_past_checkpoints > 0:
self.cleanup_old(keep=keep_x_past_checkpoints) self.cleanup_old(keep=keep_x_past_checkpoints)
if start: if start:
@ -676,7 +675,7 @@ class TrainingState():
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it' self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
self.it_rates += it_rate self.it_rates += it_rate
epoch_rate = self.it_rates / self.it * self.epoch epoch_rate = self.it_rates / self.it * self.steps
if epoch_rate > 0: if epoch_rate > 0:
self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch' self.epoch_rate = f'{"{:.3f}".format(1/epoch_rate)}epoch/s' if 0 < epoch_rate and epoch_rate < 1 else f'{"{:.3f}".format(epoch_rate)}s/epoch'
@ -710,6 +709,72 @@ class TrainingState():
return data return data
def get_status(self):
message = None
self.metrics['rate'] = []
if self.epoch_rate:
self.metrics['rate'].append(self.epoch_rate)
if self.it_rate and self.epoch_rate[:-7] != self.it_rate[:-4]:
self.metrics['rate'].append(self.it_rate)
self.metrics['rate'] = ", ".join(self.metrics['rate'])
eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"
self.metrics['loss'] = []
if 'lr' in self.info:
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
if len(self.losses) > 0:
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
if len(self.losses) >= 2:
deriv = 0
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
loss_value = self.losses[-1]["value"]
for i in range(accum_length):
d1_loss = self.losses[accum_length-i-1]["value"]
d2_loss = self.losses[accum_length-i-2]["value"]
dloss = (d2_loss - d1_loss)
d1_step = self.losses[accum_length-i-1]["epoch"]
d2_step = self.losses[accum_length-i-2]["epoch"]
dstep = (d2_step - d1_step)
if dstep == 0:
continue
inst_deriv = dloss / dstep
deriv += inst_deriv
deriv = deriv / accum_length
if deriv != 0: # dloss < 0:
next_milestone = None
for milestone in self.loss_milestones:
if loss_value > milestone:
next_milestone = milestone
break
if next_milestone:
# tfw can do simple calculus but not basic algebra in my head
est_its = (next_milestone - loss_value) / deriv
if est_its >= 0:
self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
else:
est_loss = inst_deriv * (self.its - self.it) + loss_value
if est_loss >= 0:
self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}')
self.metrics['loss'] = ", ".join(self.metrics['loss'])
message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
if self.nan_detected:
message = f"[!NaN DETECTED! {self.nan_detected}] {message}"
return message
def load_statistics(self, update=False): def load_statistics(self, update=False):
if not os.path.isdir(f'{self.dataset_dir}/'): if not os.path.isdir(f'{self.dataset_dir}/'):
return return
@ -720,6 +785,7 @@ class TrainingState():
if not update: if not update:
self.statistics['loss'] = [] self.statistics['loss'] = []
self.statistics['lr'] = [] self.statistics['lr'] = []
self.it_rates = 0
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ]) logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
if update: if update:
@ -742,12 +808,13 @@ class TrainingState():
if "it" not in data: if "it" not in data:
continue continue
step = data['it'] it = data['it']
if update and step <= self.last_info_check_at: if update and it <= self.last_info_check_at:
continue continue
self.parse_metrics(data) self.parse_metrics(data)
# print(f"Iterations Left: {self.its - self.it} | Elapsed Time: {self.it_rates} | Time Remaining: {self.eta} | Message: {self.get_status()}")
self.last_info_check_at = highest_step self.last_info_check_at = highest_step
@ -795,11 +862,10 @@ class TrainingState():
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq']) self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
self.load_statistics()
should_return = True should_return = True
else: else:
message = None
data = None
# INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028} # INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
if line.find('INFO: Training Metrics:') >= 0: if line.find('INFO: Training Metrics:') >= 0:
data = json.loads(line.split("INFO: Training Metrics:")[-1]) data = json.loads(line.split("INFO: Training Metrics:")[-1])
@ -809,72 +875,11 @@ class TrainingState():
data['mode'] = "validation" data['mode'] = "validation"
if data is not None: if data is not None:
self.parse_metrics( data )
should_return = True
if ': nan' in line and not self.nan_detected: if ': nan' in line and not self.nan_detected:
self.nan_detected = self.it self.nan_detected = self.it
self.metrics['rate'] = [] self.parse_metrics( data )
if self.epoch_rate: message = self.get_status()
self.metrics['rate'].append(self.epoch_rate)
if self.it_rate and self.epoch_rate[:-7] != self.it_rate[:-4]:
self.metrics['rate'].append(self.it_rate)
self.metrics['rate'] = ", ".join(self.metrics['rate'])
eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"
self.metrics['loss'] = []
if 'lr' in self.info:
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
if len(self.losses) > 0:
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
if len(self.losses) >= 2:
deriv = 0
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
loss_value = self.losses[-1]["value"]
for i in range(accum_length):
d1_loss = self.losses[accum_length-i-1]["value"]
d2_loss = self.losses[accum_length-i-2]["value"]
dloss = (d2_loss - d1_loss)
d1_step = self.losses[accum_length-i-1]["step"]
d2_step = self.losses[accum_length-i-2]["step"]
dstep = (d2_step - d1_step)
if dstep == 0:
continue
inst_deriv = dloss / dstep
deriv += inst_deriv
deriv = deriv / accum_length
if deriv != 0: # dloss < 0:
next_milestone = None
for milestone in self.loss_milestones:
if loss_value > milestone:
next_milestone = milestone
break
if next_milestone:
# tfw can do simple calculus but not basic algebra in my head
est_its = (next_milestone - loss_value) / deriv
if est_its >= 0:
self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
else:
est_loss = inst_deriv * (self.its - self.it) + loss_value
if est_loss >= 0:
self.metrics['loss'].append(f'Est. final loss: {"{:.3f}".format(est_loss)}')
self.metrics['loss'] = ", ".join(self.metrics['loss'])
message = f"[{self.metrics['epoch']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
if self.nan_detected:
message = f"[!NaN DETECTED! {self.nan_detected}] {message}"
if message: if message:
percent = self.it / float(self.its) # self.epoch / float(self.epochs) percent = self.it / float(self.its) # self.epoch / float(self.epochs)
@ -882,6 +887,7 @@ class TrainingState():
progress(percent, message) progress(percent, message)
self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}') self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
should_return = True
if verbose and not self.training_started: if verbose and not self.training_started:
should_return = True should_return = True
@ -948,6 +954,10 @@ def update_training_dataplot(config_path=None):
if not training_state: if not training_state:
if config_path: if config_path:
training_state = TrainingState(config_path=config_path, start=False) training_state = TrainingState(config_path=config_path, start=False)
training_state.load_statistics()
message = training_state.get_status()
print(message)
if len(training_state.statistics['loss']) > 0: if len(training_state.statistics['loss']) > 0:
losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,) losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.epochs], x="epoch", y="value", title="Loss Metrics", color="type", tooltip=['epoch', 'value', 'type'], width=500, height=350,)
if len(training_state.statistics['lr']) > 0: if len(training_state.statistics['lr']) > 0: