From 9e320a34c8b97bfca0586cadedd1be2e3ead8dc4 Mon Sep 17 00:00:00 2001 From: tigi6346 Date: Sun, 12 Mar 2023 08:00:03 +0200 Subject: [PATCH 1/2] Fixed Keep X Previous States --- src/webui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/webui.py b/src/webui.py index 9dd6475..927f262 100755 --- a/src/webui.py +++ b/src/webui.py @@ -497,7 +497,7 @@ def setup_gradio(): training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) - training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) + keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) with gr.Row(): start_training_button = gr.Button(value="Train") stop_training_button = gr.Button(value="Stop") @@ -708,7 +708,7 @@ def setup_gradio(): inputs=[ training_configs, verbose_training, - training_keep_x_past_datasets, + keep_x_past_checkpoints, ], outputs=[ training_output, From 29b3d1ae1d76ff645f3c314abe19810df019f2ff Mon Sep 17 00:00:00 2001 From: tigi6346 Date: Sun, 12 Mar 2023 08:01:08 +0200 Subject: [PATCH 2/2] Fixed Keep X Previous States --- src/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/utils.py b/src/utils.py index 6da2af4..802918e 100755 --- a/src/utils.py +++ b/src/utils.py @@ -752,8 +752,8 @@ class TrainingState(): models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ]) states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ]) - remove_models = models[:-2] - remove_states = states[:-2] + remove_models = models[:-keep] + remove_states = states[:-keep] for d in remove_models: path = f'{self.dataset_dir}/models/{d}_gpt.pth' @@ -898,6 +898,9 @@ class TrainingState(): if should_return: result = "".join(self.buffer) if not self.training_started else message + if keep_x_past_checkpoints > 0: + self.cleanup_old(keep=keep_x_past_checkpoints) + return ( result, percent,