From 843bfbfb968f3e5eeb2754dc3ea5dc1b55c0855e Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 18 Feb 2023 14:51:00 +0000
Subject: [PATCH] Simplified generating training YAML, cleaned it up, training
 output is cleaned up and will "autoscroll" (only show the last 8 lines,
 refer to console for a full trace if needed)

---
 src/utils.py | 20 ++++++++++++++------
 src/webui.py | 42 ++++++++++++++++++++++++++++++++++++++----
 2 files changed, 52 insertions(+), 10 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index e52b2db..9339a82 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -449,9 +449,9 @@ def run_training(config_path):
 	training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
 	buffer=[]
 	for line in iter(training_process.stdout.readline, ""):
-		buffer.append(line)
-		print(line[:-1])
-		yield "".join(buffer)
+		buffer.append(f'[{datetime.now().isoformat()}] {line}')
+		print(f"[Training] {line[:-1]}")
+		yield "".join(buffer[-8:])
 	training_process.stdout.close()
 	return_code = training_process.wait()
 
@@ -498,7 +498,7 @@ def setup_tortoise(restart=False):
 	print("TorToiSe initialized, ready for generation.")
 	return tts
 
-def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ):
+def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
 	settings = {
 		"batch_size": batch_size if batch_size else 128,
 		"learning_rate": learning_rate if learning_rate else 1e-5,
@@ -510,7 +510,11 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
 		"validation_name": validation_name if validation_name else "finetune",
 		"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
 	}
-	outfile = f'./training/{settings["name"]}.yaml'
+
+	if not output_name:
+		output_name = f'{settings["name"]}.yaml'
+
+	outfile = f'./training/{output_name}'
 
 	with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
 		yaml = f.read()
@@ -724,9 +728,13 @@ def get_autoregressive_models(dir="./models/finetuned/"):
 	os.makedirs(dir, exist_ok=True)
 	return [get_model_path('autoregressive.pth')] + sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
 
+def get_dataset_list(dir="./training/"):
+	return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
+
+def get_training_list(dir="./training/"):
+	return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
 
 def update_autoregressive_model(path_name):
-	global tts
 	if not tts:
 		raise Exception("TTS is uninitialized or still initializing...")
 
diff --git a/src/webui.py b/src/webui.py
index 0ce4759..1186b1e 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -131,7 +131,7 @@ def get_training_configs():
 	return configs
 
 def update_training_configs():
-	return gr.update(choices=get_training_configs())
+	return gr.update(choices=get_training_list())
 
 history_headers = {
 	"Name": "",
@@ -201,6 +201,24 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
 def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
 	return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
 
+def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ):
+	name = f"{voice}-finetune"
+	dataset_name = f"{voice}-train"
+	dataset_path = f"./training/{voice}/train.txt"
+	validation_name = f"{voice}-val"
+	validation_path = f"./training/{voice}/train.txt"
+
+	with open(dataset_path, 'r', encoding="utf-8") as f:
+		lines = len(f.readlines())
+
+	if batch_size > lines:
+		print("Batch size is larger than your dataset, clamping...")
+		batch_size = lines
+
+	out_name = f"{voice}/train.yaml"
+
+	return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
+
 def update_voices():
 	return (
 		gr.Dropdown.update(choices=get_voice_list()),
@@ -333,6 +351,12 @@ def setup_gradio():
 							gr.Number(label="Print Frequency", value=50),
 							gr.Number(label="Save Frequency", value=50),
 						]
+						dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
+						training_settings = training_settings + [
+							dataset_list
+						]
+						refresh_dataset_list = gr.Button(value="Refresh Dataset List")
+						"""
 						training_settings = training_settings + [
 							gr.Textbox(label="Training Name", placeholder="finetune"),
 							gr.Textbox(label="Dataset Name", placeholder="finetune"),
@@ -340,13 +364,14 @@ def setup_gradio():
 							gr.Textbox(label="Validation Name", placeholder="finetune"),
 							gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
 						]
+						"""
 					with gr.Column():
 						save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
 						save_yaml_button = gr.Button(value="Save Training Configuration")
 			with gr.Tab("Run Training"):
 				with gr.Row():
 					with gr.Column():
-						training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
+						training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
 						refresh_configs = gr.Button(value="Refresh Configurations")
 						start_training_button = gr.Button(value="Train")
 						stop_training_button = gr.Button(value="Stop")
@@ -524,7 +549,11 @@ def setup_gradio():
 			outputs=input_settings
 		)
 
-		refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
+		refresh_configs.click(
+			lambda: gr.update(choices=get_training_list()),
+			inputs=None,
+			outputs=training_configs
+		)
 		start_training_button.click(run_training,
			inputs=training_configs,
 			outputs=training_output #console_output
@@ -538,7 +567,12 @@ def setup_gradio():
 			inputs=dataset_settings,
 			outputs=prepare_dataset_output #console_output
 		)
-		save_yaml_button.click(save_training_settings,
+		refresh_dataset_list.click(
+			lambda: gr.update(choices=get_dataset_list()),
+			inputs=None,
+			outputs=dataset_list,
+		)
+		save_yaml_button.click(save_training_settings_proxy,
 			inputs=training_settings,
 			outputs=save_yaml_output #console_output
 		)
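
Note: the "autoscroll" change in run_training above amounts to keeping a timestamped rolling buffer of the trainer's stdout and yielding only its tail to the UI. A minimal self-contained sketch of the same pattern follows; the stream_tail name and the demo command are illustrative stand-ins, not part of the patch:

import subprocess
import sys
from datetime import datetime

def stream_tail(cmd, tail=8):
	# Stream a child process's combined stdout/stderr line by line,
	# keeping a timestamped history but yielding only the last `tail` lines.
	process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
	buffer = []
	for line in iter(process.stdout.readline, ""):
		buffer.append(f'[{datetime.now().isoformat()}] {line}')
		print(f"[Training] {line[:-1]}")  # full trace still reaches the console
		yield "".join(buffer[-tail:])     # the UI only ever sees the tail
	process.stdout.close()
	process.wait()

if __name__ == "__main__":
	demo = [sys.executable, "-c", "for i in range(20): print(i)"]
	view = ""
	for view in stream_tail(demo):
		pass  # in the web UI, each yielded `view` replaces the textbox contents
	print("--- final view ---")
	print(view, end="")

Because the generator yields the joined tail on every line, the output textbox (max_lines=8 in the patch) stays bounded while print() preserves the complete log in the console.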