stretch loss plot to target iteration just so its not so misleading with the scale

This commit is contained in:
mrq 2023-03-06 00:44:29 +00:00
parent 5be14abc21
commit 788a957f79

View File

@ -924,12 +924,6 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
return_code = training_state.process.wait() return_code = training_state.process.wait()
training_state = None training_state = None
def get_training_losses():
global training_state
if not training_state or not training_state.statistics:
return
return pd.DataFrame(training_state.statistics)
def update_training_dataplot(config_path=None): def update_training_dataplot(config_path=None):
global training_state global training_state
update = None update = None
@ -938,12 +932,12 @@ def update_training_dataplot(config_path=None):
if config_path: if config_path:
training_state = TrainingState(config_path=config_path, start=False) training_state = TrainingState(config_path=config_path, start=False)
if training_state.statistics: if training_state.statistics:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics)) update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
del training_state del training_state
training_state = None training_state = None
elif training_state.statistics: elif training_state.statistics:
training_state.load_losses() training_state.load_losses()
update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics)) update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
return update return update