|
|
@ -209,6 +209,9 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra
|
|
|
|
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
|
|
|
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
|
|
|
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
|
|
|
|
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print_rate = int(print_rate * iterations / epochs)
|
|
|
|
|
|
|
|
save_rate = int(save_rate * iterations / epochs)
|
|
|
|
|
|
|
|
|
|
|
|
messages.append(save_training_settings(
|
|
|
|
messages.append(save_training_settings(
|
|
|
|
iterations=iterations,
|
|
|
|
iterations=iterations,
|
|
|
|
batch_size=batch_size,
|
|
|
|
batch_size=batch_size,
|
|
|
@ -352,13 +355,13 @@ def setup_gradio():
|
|
|
|
with gr.Row():
|
|
|
|
with gr.Row():
|
|
|
|
with gr.Column():
|
|
|
|
with gr.Column():
|
|
|
|
training_settings = [
|
|
|
|
training_settings = [
|
|
|
|
gr.Slider(label="Epochs", minimum=0, maximum=500, value=10),
|
|
|
|
gr.Number(label="Epochs", value=10, precision=0),
|
|
|
|
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
|
|
|
gr.Number(label="Batch Size", value=64, precision=0),
|
|
|
|
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
|
|
|
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
|
|
|
gr.Textbox(label="Learning Rate Schedule", placeholder="[ 100, 200, 300, 400 ]"),
|
|
|
|
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
|
|
|
|
gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1),
|
|
|
|
gr.Number(label="Mega Batch Factor", value=4, precision=0),
|
|
|
|
gr.Number(label="Print Frequency", value=50),
|
|
|
|
gr.Number(label="Print Frequency per Epoch", value=5, precision=0),
|
|
|
|
gr.Number(label="Save Frequency", value=50),
|
|
|
|
gr.Number(label="Save Frequency per Epoch", value=5, precision=0),
|
|
|
|
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
|
|
|
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
|
|
|
]
|
|
|
|
]
|
|
|
|
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
|
|
|
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
|
|
|