diff --git a/src/utils.py b/src/utils.py index d23f046..5c1c449 100755 --- a/src/utils.py +++ b/src/utils.py @@ -550,8 +550,11 @@ class TrainingState(): self.eta_hhmmss = "?" self.last_info_check_at = 0 + self.statistics = [] self.losses = [] + self.loss_milestones = [ 1.0, 0.15, 0.05 ] + self.load_losses() if keep_x_past_datasets > 0: self.cleanup_old(keep=keep_x_past_datasets) @@ -578,7 +581,7 @@ class TrainingState(): highest_step = self.last_info_check_at if not update: - self.losses = [] + self.statistics = [] if use_tensorboard: logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ]) @@ -596,9 +599,12 @@ class TrainingState(): if update and s.step <= self.last_info_check_at: continue highest_step = max( highest_step, s.step ) - self.losses.append( { "step": s.step, "value": s.value, "type": key } ) + self.statistics.append( { "step": s.step, "value": s.value, "type": key } ) + + if key == 'loss_gpt_total': + self.losses.append( { "step": s.step, "value": s.value, "type": key } ) + except Exception as e: - print("Failed to parse event log:", log) pass else: @@ -630,7 +636,10 @@ class TrainingState(): if update and int(k) <= self.last_info_check_at: continue highest_step = max( highest_step, s.step ) - self.losses.append({ "step": int(k), "value": infos[k][key], "type": key }) + self.statistics.append({ "step": int(k), "value": infos[k][key], "type": key }) + + if key == "loss_gpt_total": + self.losses.append({ "step": int(k), "value": infos[k][key], "type": key }) self.last_info_check_at = highest_step @@ -739,20 +748,36 @@ class TrainingState(): metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}') if len(self.losses) >= 2: - delta_loss = self.losses[-2]["value"] - self.losses[-1]["value"] - delta_step = self.losses[-2]["step"] - self.losses[-1]["step"] - - inst_deriv = delta_loss / delta_step - est_loss = delta_loss + (self.its - self.it) * inst_deriv - metric_loss.append(f'Est. Final Loss: {"{:3f}".format(est_loss)}') - - print(delta_loss, delta_step, inst_deriv, est_loss) - + # i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine + d1_loss = self.losses[-1]["value"] + d2_loss = self.losses[-2]["value"] + dloss = d2_loss - d1_loss + + d1_step = self.losses[-1]["step"] + d2_step = self.losses[-2]["step"] + dstep = d2_step - d1_step + + # don't bother if the loss went up + if dloss < 0: + its_remain = self.its - self.it + inst_deriv = dloss / dstep + + next_milestone = None + for milestone in self.loss_milestones: + if d1_loss > milestone: + next_milestone = milestone + break + + if next_milestone: + # tfw can do simple calculus but not basic algebra in my head + est_its = (next_milestone - d1_loss) * (dstep / dloss) + metric_loss.append(f'Est. milestone {next_milestone} in: {int(est_its)}its') + else: + est_loss = inst_deriv * its_remain + d1_loss + metric_loss.append(f'Est. final loss: {"{:3f}".format(est_loss)}') metric_loss = ", ".join(metric_loss) - - message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]' if lapsed: @@ -859,9 +884,9 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro def get_training_losses(): global training_state - if not training_state or not training_state.losses: + if not training_state or not training_state.statistics: return - return pd.DataFrame(training_state.losses) + return pd.DataFrame(training_state.statistics) def update_training_dataplot(config_path=None): global training_state @@ -870,13 +895,13 @@ def update_training_dataplot(config_path=None): if not training_state: if config_path: training_state = TrainingState(config_path=config_path, start=False) - if training_state.losses: - update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses)) + if training_state.statistics: + update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics)) del training_state training_state = None - elif training_state.losses: + elif training_state.statistics: training_state.load_losses() - update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses)) + update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics)) return update @@ -943,7 +968,7 @@ def whisper_transcribe( file, language=None ): if not args.whisper_cpp: if not language: language = None - + return whisper_model.transcribe(file, language=language) res = whisper_model.transcribe(file)