forked from mrq/ai-voice-cloning
optimize batch sizes to be as evenly divisible as possible (noticed the calculated epochs mismatched the inputted epochs)
This commit is contained in:
parent
6260594a1e
commit
ee95616dfd
14
src/utils.py
14
src/utils.py
|
@ -608,8 +608,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
|
|||
mega_batch_factor = int(batch_size / 2)
|
||||
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
|
||||
|
||||
if batch_size % lines != 0:
|
||||
nearest_slice = int(lines / batch_size) + 1
|
||||
batch_size = int(lines / nearest_slice)
|
||||
messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
|
||||
|
||||
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
||||
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
|
||||
|
||||
if epochs < print_rate:
|
||||
print_rate = epochs
|
||||
|
@ -623,8 +627,7 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
|
|||
resume_path = None
|
||||
messages.append("Resume path specified, but does not exist. Disabling...")
|
||||
|
||||
learning_rate_schedule = schedule_learning_rate( iterations / epochs ) # faster learning schedule compared to just passing lines / batch_size due to truncating
|
||||
messages.append(f"Suggesting best learning rate schedule for iterations: {learning_rate_schedule}")
|
||||
messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
|
||||
|
||||
return (
|
||||
batch_size,
|
||||
|
@ -637,12 +640,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
|
|||
messages
|
||||
)
|
||||
|
||||
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
|
||||
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
|
||||
settings = {
|
||||
"iterations": iterations if iterations else 500,
|
||||
"batch_size": batch_size if batch_size else 64,
|
||||
"learning_rate": learning_rate if learning_rate else 1e-5,
|
||||
"gen_lr_steps": learning_rate_schedule if learning_rate_schedule else [ 200, 300, 400, 500 ],
|
||||
"gen_lr_steps": learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE,
|
||||
"mega_batch_factor": mega_batch_factor if mega_batch_factor else 4,
|
||||
"print_rate": print_rate if print_rate else 50,
|
||||
"save_rate": save_rate if save_rate else 50,
|
||||
|
@ -656,6 +659,7 @@ def save_training_settings( iterations=None, batch_size=None, learning_rate=None
|
|||
'pretrain_model_gpt': "pretrain_model_gpt: './models/tortoise/autoregressive.pth'" if not resume_path else "# pretrain_model_gpt: './models/tortoise/autoregressive.pth'"
|
||||
}
|
||||
|
||||
|
||||
if not output_name:
|
||||
output_name = f'{settings["name"]}.yaml'
|
||||
|
||||
|
|
12
src/webui.py
12
src/webui.py
|
@ -212,6 +212,10 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra
|
|||
print_rate = int(print_rate * iterations / epochs)
|
||||
save_rate = int(save_rate * iterations / epochs)
|
||||
|
||||
if not learning_rate_schedule:
|
||||
learning_rate_schedule = EPOCH_SCHEDULE
|
||||
learning_rate_schedule = schedule_learning_rate( iterations / epochs )
|
||||
|
||||
messages.append(save_training_settings(
|
||||
iterations=iterations,
|
||||
batch_size=batch_size,
|
||||
|
@ -355,8 +359,8 @@ def setup_gradio():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_settings = [
|
||||
gr.Number(label="Epochs", value=10, precision=0),
|
||||
gr.Number(label="Batch Size", value=64, precision=0),
|
||||
gr.Number(label="Epochs", value=500, precision=0),
|
||||
gr.Number(label="Batch Size", value=128, precision=0),
|
||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
|
||||
gr.Number(label="Mega Batch Factor", value=4, precision=0),
|
||||
|
@ -384,13 +388,13 @@ def setup_gradio():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
|
||||
verbose_training = gr.Checkbox(label="Verbose Training")
|
||||
training_buffer_size = gr.Slider(label="Buffer Size", minimum=4, maximum=32, value=8)
|
||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||
start_training_button = gr.Button(value="Train")
|
||||
stop_training_button = gr.Button(value="Stop")
|
||||
with gr.Column():
|
||||
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
verbose_training = gr.Checkbox(label="Verbose Console Output")
|
||||
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
|
||||
with gr.Tab("Settings"):
|
||||
with gr.Row():
|
||||
exec_inputs = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user