forked from mrq/ai-voice-cloning
Merge pull request 'keep_training' (#118) from zim33/ai-voice-cloning:keep_training into master
Reviewed-on: mrq/ai-voice-cloning#118
This commit is contained in: commit 1ac278e885
@@ -752,8 +752,8 @@ class TrainingState():
         models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
         states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
-        remove_models = models[:-2]
-        remove_states = states[:-2]
+        remove_models = models[:-keep]
+        remove_states = states[:-keep]
 
         for d in remove_models:
             path = f'{self.dataset_dir}/models/{d}_gpt.pth'
@@ -898,6 +898,9 @@ class TrainingState():
         if should_return:
             result = "".join(self.buffer) if not self.training_started else message
 
+        if keep_x_past_checkpoints > 0:
+            self.cleanup_old(keep=keep_x_past_checkpoints)
+
         return (
             result,
             percent,
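Taken together, the two TrainingState hunks above replace the hardcoded "keep the last 2" pruning with a caller-controlled keep count, gated so that a value of 0 disables cleanup entirely. A minimal standalone sketch of that behavior, assuming the directory layout shown in the diff and plain os.remove for deletion (the real code is a method that reads self.dataset_dir):

import os

def cleanup_old(dataset_dir, keep=2):
    # Sketch only: prune all but the newest `keep` checkpoints and trainer states.
    # Files are named <iteration>_gpt.pth and <iteration>.state, so sorting the
    # parsed iteration numbers and slicing with [:-keep] leaves the last `keep`.
    if keep <= 0:
        return
    models = sorted([ int(d[:-8]) for d in os.listdir(f'{dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
    states = sorted([ int(d[:-6]) for d in os.listdir(f'{dataset_dir}/training_state/') if d[-6:] == ".state" ])
    for it in models[:-keep]:
        os.remove(f'{dataset_dir}/models/{it}_gpt.pth')        # assumption: simple file delete
    for it in states[:-keep]:
        os.remove(f'{dataset_dir}/training_state/{it}.state')  # assumption: simple file delete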
@@ -497,7 +497,7 @@ def setup_gradio():
             training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
             verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
 
-            training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
+            keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
             with gr.Row():
                 start_training_button = gr.Button(value="Train")
                 stop_training_button = gr.Button(value="Stop")
@@ -708,7 +708,7 @@ def setup_gradio():
         inputs=[
             training_configs,
             verbose_training,
-            training_keep_x_past_datasets,
+            keep_x_past_checkpoints,
         ],
         outputs=[
             training_output,
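The two setup_gradio hunks rename the slider to keep_x_past_checkpoints and pass its value through the Train button's inputs. A stripped-down sketch of that wiring pattern, with the component set and the start_training handler reduced to hypothetical stand-ins:

import gradio as gr

def start_training(config_path, verbose, keep_x_past_checkpoints):
    # Hypothetical handler: the slider value rides along with the other inputs,
    # so the training loop can later call cleanup_old(keep=keep_x_past_checkpoints).
    return f"training {config_path} (verbose={verbose}, keep={keep_x_past_checkpoints})"

with gr.Blocks() as demo:
    training_configs = gr.Dropdown(label="Training Configuration", choices=["train.yaml"])
    verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
    keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
    training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
    start_training_button = gr.Button(value="Train")
    start_training_button.click(
        start_training,
        inputs=[ training_configs, verbose_training, keep_x_past_checkpoints ],
        outputs=[ training_output ],
    )

Launching the sketch with demo.launch() exposes the same "Keep X Previous States" control the diff adds to the real UI.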