1
0
Fork 0

optimize batch sizes to be as evenly divisible as possible (noticed the calculated epochs mismatched the inputted epochs)

master
mrq 2023-02-19 21:06:14 +07:00
parent 6260594a1e
commit ee95616dfd
2 changed files with 17 additions and 9 deletions

@ -608,8 +608,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
mega_batch_factor = int(batch_size / 2)
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
if batch_size % lines != 0:
nearest_slice = int(lines / batch_size) + 1
batch_size = int(lines / nearest_slice)
messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
if epochs < print_rate:
print_rate = epochs
@ -623,8 +627,7 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
resume_path = None
messages.append("Resume path specified, but does not exist. Disabling...")
learning_rate_schedule = schedule_learning_rate( iterations / epochs ) # faster learning schedule compared to just passing lines / batch_size due to truncating
messages.append(f"Suggesting best learning rate schedule for iterations: {learning_rate_schedule}")
messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")
return (
batch_size,
@ -637,12 +640,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
messages
)
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
settings = {
"iterations": iterations if iterations else 500,
"batch_size": batch_size if batch_size else 64,
"learning_rate": learning_rate if learning_rate else 1e-5,
"gen_lr_steps": learning_rate_schedule if learning_rate_schedule else [ 200, 300, 400, 500 ],
"gen_lr_steps": learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE,
"mega_batch_factor": mega_batch_factor if mega_batch_factor else 4,
"print_rate": print_rate if print_rate else 50,
"save_rate": save_rate if save_rate else 50,
@ -656,6 +659,7 @@ def save_training_settings( iterations=None, batch_size=None, learning_rate=None
'pretrain_model_gpt': "pretrain_model_gpt: './models/tortoise/autoregressive.pth'" if not resume_path else "# pretrain_model_gpt: './models/tortoise/autoregressive.pth'"
}
if not output_name:
output_name = f'{settings["name"]}.yaml'

@ -212,6 +212,10 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra
print_rate = int(print_rate * iterations / epochs)
save_rate = int(save_rate * iterations / epochs)
if not learning_rate_schedule:
learning_rate_schedule = EPOCH_SCHEDULE
learning_rate_schedule = schedule_learning_rate( iterations / epochs )
messages.append(save_training_settings(
iterations=iterations,
batch_size=batch_size,
@ -355,8 +359,8 @@ def setup_gradio():
with gr.Row():
with gr.Column():
training_settings = [
gr.Number(label="Epochs", value=10, precision=0),
gr.Number(label="Batch Size", value=64, precision=0),
gr.Number(label="Epochs", value=500, precision=0),
gr.Number(label="Batch Size", value=128, precision=0),
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)),
gr.Number(label="Mega Batch Factor", value=4, precision=0),
@ -384,13 +388,13 @@ def setup_gradio():
with gr.Row():
with gr.Column():
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
verbose_training = gr.Checkbox(label="Verbose Training")
training_buffer_size = gr.Slider(label="Buffer Size", minimum=4, maximum=32, value=8)
refresh_configs = gr.Button(value="Refresh Configurations")
start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop")
with gr.Column():
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output")
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
with gr.Tab("Settings"):
with gr.Row():
exec_inputs = []