From 788a957f797f99d23e7074b8240a6173a7829b64 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 6 Mar 2023 00:44:29 +0000 Subject: [PATCH] stretch loss plot to target iteration just so its not so misleading with the scale --- src/utils.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/utils.py b/src/utils.py index 41dd15f..eb95d19 100755 --- a/src/utils.py +++ b/src/utils.py @@ -924,12 +924,6 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro return_code = training_state.process.wait() training_state = None -def get_training_losses(): - global training_state - if not training_state or not training_state.statistics: - return - return pd.DataFrame(training_state.statistics) - def update_training_dataplot(config_path=None): global training_state update = None @@ -938,12 +932,12 @@ def update_training_dataplot(config_path=None): if config_path: training_state = TrainingState(config_path=config_path, start=False) if training_state.statistics: - update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics)) + update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,) del training_state training_state = None elif training_state.statistics: training_state.load_losses() - update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics)) + update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,) return update