From 6260594a1e10a5729b020d57383d325ca204c53a Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 19 Feb 2023 20:38:00 +0000 Subject: [PATCH] Forgot to base print/save frequencies in terms of epochs in the UI, will get converted when saving the YAML --- src/utils.py | 12 ++++++------ src/webui.py | 15 +++++++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/utils.py b/src/utils.py index 23a3547..d5cdb60 100755 --- a/src/utils.py +++ b/src/utils.py @@ -584,9 +584,9 @@ def calc_iterations( epochs, lines, batch_size ): iterations = int(epochs * lines / float(batch_size)) return iterations +EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] def schedule_learning_rate( iterations ): - schedule = [ 9, 18, 25, 33 ] - return [int(iterations * d) for d in schedule] + return [int(iterations * d) for d in EPOCH_SCHEDULE] def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ): name = f"{voice}-finetune" @@ -611,12 +611,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps") - if iterations < print_rate: - print_rate = iterations + if epochs < print_rate: + print_rate = epochs messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}") - if iterations < save_rate: - save_rate = iterations + if epochs < save_rate: + save_rate = epochs messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}") if resume_path and not os.path.exists(resume_path): diff --git a/src/webui.py b/src/webui.py index 6d775f7..bc64e27 100755 --- a/src/webui.py +++ b/src/webui.py @@ -209,6 +209,9 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps") + print_rate = int(print_rate * iterations / epochs) + save_rate = int(save_rate * iterations / epochs) + messages.append(save_training_settings( iterations=iterations, batch_size=batch_size, @@ -352,13 +355,13 @@ def setup_gradio(): with gr.Row(): with gr.Column(): training_settings = [ - gr.Slider(label="Epochs", minimum=0, maximum=500, value=10), - gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64), + gr.Number(label="Epochs", value=10, precision=0), + gr.Number(label="Batch Size", value=64, precision=0), gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), - gr.Textbox(label="Learning Rate Schedule", placeholder="[ 100, 200, 300, 400 ]"), - gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1), - gr.Number(label="Print Frequency", value=50), - gr.Number(label="Save Frequency", value=50), + gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)), + gr.Number(label="Mega Batch Factor", value=4, precision=0), + gr.Number(label="Print Frequency per Epoch", value=5, precision=0), + gr.Number(label="Save Frequency per Epoch", value=5, precision=0), gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"), ] dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )