From 4694d622f40ac20fb1084c81b8bf85e33ca0a877 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 19 Feb 2023 20:22:03 +0000 Subject: [PATCH] doing something completely unrelated had me realize it's 1000x easier to just base things in terms of epochs, and calculate iteratsions from there --- src/utils.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/webui.py | 54 ++++++++++++++++++++++++------------------------- 2 files changed, 84 insertions(+), 27 deletions(-) diff --git a/src/utils.py b/src/utils.py index 912c380..23a3547 100755 --- a/src/utils.py +++ b/src/utils.py @@ -580,6 +580,63 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm return voice +def calc_iterations( epochs, lines, batch_size ): + iterations = int(epochs * lines / float(batch_size)) + return iterations + +def schedule_learning_rate( iterations ): + schedule = [ 9, 18, 25, 33 ] + return [int(iterations * d) for d in schedule] + +def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ): + name = f"{voice}-finetune" + dataset_name = f"{voice}-train" + dataset_path = f"./training/{voice}/train.txt" + validation_name = f"{voice}-val" + validation_path = f"./training/{voice}/train.txt" + + with open(dataset_path, 'r', encoding="utf-8") as f: + lines = len(f.readlines()) + + messages = [] + + if batch_size > lines: + batch_size = lines + messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}") + + if batch_size / mega_batch_factor < 2: + mega_batch_factor = int(batch_size / 2) + messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}") + + iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) + messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps") + + if iterations < print_rate: + print_rate = iterations + messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}") + + if iterations < save_rate: + save_rate = iterations + messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}") + + if resume_path and not os.path.exists(resume_path): + resume_path = None + messages.append("Resume path specified, but does not exist. Disabling...") + + learning_rate_schedule = schedule_learning_rate( iterations / epochs ) # faster learning schedule compared to just passing lines / batch_size due to truncating + messages.append(f"Suggesting best learning rate schedule for iterations: {learning_rate_schedule}") + + return ( + batch_size, + learning_rate, + learning_rate_schedule, + mega_batch_factor, + print_rate, + save_rate, + resume_path, + messages + ) + def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ): settings = { "iterations": iterations if iterations else 500, diff --git a/src/webui.py b/src/webui.py index d1d892d..6d775f7 100755 --- a/src/webui.py +++ b/src/webui.py @@ -180,7 +180,21 @@ def read_generate_settings_proxy(file, saveAs='.temp'): def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ): return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress ) -def save_training_settings_proxy( iterations, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ): +def optimize_training_settings_proxy( *args, **kwargs ): + tup = optimize_training_settings(*args, **kwargs) + + return ( + gr.update(value=tup[0]), + gr.update(value=tup[1]), + gr.update(value=tup[2]), + gr.update(value=tup[3]), + gr.update(value=tup[4]), + gr.update(value=tup[5]), + gr.update(value=tup[6]), + "\n".join(tup[7]) + ) + +def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -192,28 +206,11 @@ def save_training_settings_proxy( iterations, batch_size, learning_rate, learnin messages = [] - if batch_size > lines: - batch_size = lines - messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}") - - - if batch_size / mega_batch_factor < 2: - mega_batch_factor = int(batch_size / 2) - messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}") + iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) + messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps") - if iterations < print_rate: - print_rate = iterations - messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}") - - if iterations < save_rate: - save_rate = iterations - messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}") - - if resume_path and not os.path.exists(resume_path): - messages.append("Resume path specified, but does not exist. Disabling...") - resume_path = None - - messages.append(save_training_settings(iterations, + messages.append(save_training_settings( + iterations=iterations, batch_size=batch_size, learning_rate=learning_rate, learning_rate_schedule=learning_rate_schedule, @@ -355,19 +352,17 @@ def setup_gradio(): with gr.Row(): with gr.Column(): training_settings = [ - gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500), + gr.Slider(label="Epochs", minimum=0, maximum=500, value=10), gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64), gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), - gr.Textbox(label="Learning Rate Schedule", placeholder="[ 200, 300, 400, 500 ]"), + gr.Textbox(label="Learning Rate Schedule", placeholder="[ 100, 200, 300, 400 ]"), gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1), gr.Number(label="Print Frequency", value=50), gr.Number(label="Save Frequency", value=50), gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"), ] dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" ) - training_settings = training_settings + [ - dataset_list - ] + training_settings = training_settings + [ dataset_list ] refresh_dataset_list = gr.Button(value="Refresh Dataset List") """ training_settings = training_settings + [ @@ -380,6 +375,7 @@ def setup_gradio(): """ with gr.Column(): save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) + optimize_yaml_button = gr.Button(value="Validate Training Configuration") save_yaml_button = gr.Button(value="Save Training Configuration") with gr.Tab("Run Training"): with gr.Row(): @@ -591,6 +587,10 @@ def setup_gradio(): inputs=None, outputs=dataset_list, ) + optimize_yaml_button.click(optimize_training_settings_proxy, + inputs=training_settings, + outputs=training_settings[1:8] + [save_yaml_output] #console_output + ) save_yaml_button.click(save_training_settings_proxy, inputs=training_settings, outputs=save_yaml_output #console_output