Forgot to base print/save frequencies in terms of epochs in the UI, will get converted when saving the YAML

This commit is contained in:
mrq 2023-02-19 20:38:00 +00:00
parent 4694d622f4
commit 6260594a1e
2 changed files with 15 additions and 12 deletions

View File

@ -584,9 +584,9 @@ def calc_iterations( epochs, lines, batch_size ):
iterations = int(epochs * lines / float(batch_size))
return iterations
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
def schedule_learning_rate( iterations ):
schedule = [ 9, 18, 25, 33 ]
return [int(iterations * d) for d in schedule]
return [int(iterations * d) for d in EPOCH_SCHEDULE]
def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
name = f"{voice}-finetune"
@ -611,12 +611,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
if iterations < print_rate:
print_rate = iterations
if epochs < print_rate:
print_rate = epochs
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
if iterations < save_rate:
save_rate = iterations
if epochs < save_rate:
save_rate = epochs
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
if resume_path and not os.path.exists(resume_path):

View File

@ -209,6 +209,9 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
print_rate = int(print_rate * iterations / epochs)
save_rate = int(save_rate * iterations / epochs)
messages.append(save_training_settings(
iterations=iterations,
batch_size=batch_size,
@ -352,13 +355,13 @@ def setup_gradio():
with gr.Row():
with gr.Column():
training_settings = [
gr.Slider(label="Epochs", minimum=0, maximum=500, value=10),
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
gr.Number(label="Epochs", value=10, precision=0),
gr.Number(label="Batch Size", value=64, precision=0),
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
gr.Textbox(label="Learning Rate Schedule", placeholder="[ 100, 200, 300, 400 ]"),
gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1),
gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50),
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
gr.Number(label="Mega Batch Factor", value=4, precision=0),
gr.Number(label="Print Frequency per Epoch", value=5, precision=0),
gr.Number(label="Save Frequency per Epoch", value=5, precision=0),
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
]
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )