forked from mrq/ai-voice-cloning
Forgot to base print/save frequencies in terms of epochs in the UI, will get converted when saving the YAML
This commit is contained in:
parent
4694d622f4
commit
6260594a1e
12
src/utils.py
12
src/utils.py
|
@ -584,9 +584,9 @@ def calc_iterations( epochs, lines, batch_size ):
|
|||
iterations = int(epochs * lines / float(batch_size))
|
||||
return iterations
|
||||
|
||||
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
|
||||
def schedule_learning_rate( iterations ):
|
||||
schedule = [ 9, 18, 25, 33 ]
|
||||
return [int(iterations * d) for d in schedule]
|
||||
return [int(iterations * d) for d in EPOCH_SCHEDULE]
|
||||
|
||||
def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
|
@ -611,12 +611,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
|
|||
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
||||
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
|
||||
|
||||
if iterations < print_rate:
|
||||
print_rate = iterations
|
||||
if epochs < print_rate:
|
||||
print_rate = epochs
|
||||
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
|
||||
|
||||
if iterations < save_rate:
|
||||
save_rate = iterations
|
||||
if epochs < save_rate:
|
||||
save_rate = epochs
|
||||
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
||||
|
||||
if resume_path and not os.path.exists(resume_path):
|
||||
|
|
15
src/webui.py
15
src/webui.py
|
@ -209,6 +209,9 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra
|
|||
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
||||
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
|
||||
|
||||
print_rate = int(print_rate * iterations / epochs)
|
||||
save_rate = int(save_rate * iterations / epochs)
|
||||
|
||||
messages.append(save_training_settings(
|
||||
iterations=iterations,
|
||||
batch_size=batch_size,
|
||||
|
@ -352,13 +355,13 @@ def setup_gradio():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_settings = [
|
||||
gr.Slider(label="Epochs", minimum=0, maximum=500, value=10),
|
||||
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
||||
gr.Number(label="Epochs", value=10, precision=0),
|
||||
gr.Number(label="Batch Size", value=64, precision=0),
|
||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||
gr.Textbox(label="Learning Rate Schedule", placeholder="[ 100, 200, 300, 400 ]"),
|
||||
gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1),
|
||||
gr.Number(label="Print Frequency", value=50),
|
||||
gr.Number(label="Save Frequency", value=50),
|
||||
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
|
||||
gr.Number(label="Mega Batch Factor", value=4, precision=0),
|
||||
gr.Number(label="Print Frequency per Epoch", value=5, precision=0),
|
||||
gr.Number(label="Save Frequency per Epoch", value=5, precision=0),
|
||||
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
||||
]
|
||||
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
||||
|
|
Loading…
Reference in New Issue
Block a user