Merge pull request 'keep_training' (#118) from zim33/ai-voice-cloning:keep_training into master

Reviewed-on: mrq/ai-voice-cloning#118
This commit is contained in:
mrq 2023-03-12 06:47:01 +00:00
commit 1ac278e885
2 changed files with 7 additions and 4 deletions

View File

@ -752,8 +752,8 @@ class TrainingState():
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ]) models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ]) states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
remove_models = models[:-2] remove_models = models[:-keep]
remove_states = states[:-2] remove_states = states[:-keep]
for d in remove_models: for d in remove_models:
path = f'{self.dataset_dir}/models/{d}_gpt.pth' path = f'{self.dataset_dir}/models/{d}_gpt.pth'
@ -898,6 +898,9 @@ class TrainingState():
if should_return: if should_return:
result = "".join(self.buffer) if not self.training_started else message result = "".join(self.buffer) if not self.training_started else message
if keep_x_past_checkpoints > 0:
self.cleanup_old(keep=keep_x_past_checkpoints)
return ( return (
result, result,
percent, percent,

View File

@ -497,7 +497,7 @@ def setup_gradio():
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
with gr.Row(): with gr.Row():
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop") stop_training_button = gr.Button(value="Stop")
@ -708,7 +708,7 @@ def setup_gradio():
inputs=[ inputs=[
training_configs, training_configs,
verbose_training, verbose_training,
training_keep_x_past_datasets, keep_x_past_checkpoints,
], ],
outputs=[ outputs=[
training_output, training_output,