diff --git a/src/utils.py b/src/utils.py index d5cdb60..9bac52b 100755 --- a/src/utils.py +++ b/src/utils.py @@ -608,8 +608,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate mega_batch_factor = int(batch_size / 2) messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}") + if batch_size % lines != 0: + nearest_slice = int(lines / batch_size) + 1 + batch_size = int(lines / nearest_slice) + messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)") + iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size) - messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps") if epochs < print_rate: print_rate = epochs @@ -623,8 +627,7 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate resume_path = None messages.append("Resume path specified, but does not exist. Disabling...") - learning_rate_schedule = schedule_learning_rate( iterations / epochs ) # faster learning schedule compared to just passing lines / batch_size due to truncating - messages.append(f"Suggesting best learning rate schedule for iterations: {learning_rate_schedule}") + messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)") return ( batch_size, @@ -637,12 +640,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate messages ) -def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ): +def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ): settings = { "iterations": iterations if iterations else 500, "batch_size": batch_size if batch_size else 64, "learning_rate": learning_rate if learning_rate else 1e-5, - "gen_lr_steps": learning_rate_schedule if learning_rate_schedule else [ 200, 300, 400, 500 ], + "gen_lr_steps": learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE, "mega_batch_factor": mega_batch_factor if mega_batch_factor else 4, "print_rate": print_rate if print_rate else 50, "save_rate": save_rate if save_rate else 50, @@ -656,6 +659,7 @@ def save_training_settings( iterations=None, batch_size=None, learning_rate=None 'pretrain_model_gpt': "pretrain_model_gpt: './models/tortoise/autoregressive.pth'" if not resume_path else "# pretrain_model_gpt: './models/tortoise/autoregressive.pth'" } + if not output_name: output_name = f'{settings["name"]}.yaml' diff --git a/src/webui.py b/src/webui.py index bc64e27..83aff5f 100755 --- a/src/webui.py +++ b/src/webui.py @@ -212,6 +212,10 @@ def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_ra print_rate = int(print_rate * iterations / epochs) save_rate = int(save_rate * iterations / epochs) + if not learning_rate_schedule: + learning_rate_schedule = EPOCH_SCHEDULE + learning_rate_schedule = schedule_learning_rate( iterations / epochs ) + messages.append(save_training_settings( iterations=iterations, batch_size=batch_size, @@ -355,8 +359,8 @@ def setup_gradio(): with gr.Row(): with gr.Column(): training_settings = [ - gr.Number(label="Epochs", value=10, precision=0), - gr.Number(label="Batch Size", value=64, precision=0), + gr.Number(label="Epochs", value=500, precision=0), + gr.Number(label="Batch Size", value=128, precision=0), gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), gr.Textbox(label="Learning Rate Schedule", placeholder=str(EPOCH_SCHEDULE)), gr.Number(label="Mega Batch Factor", value=4, precision=0), @@ -384,13 +388,13 @@ def setup_gradio(): with gr.Row(): with gr.Column(): training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list()) - verbose_training = gr.Checkbox(label="Verbose Training") - training_buffer_size = gr.Slider(label="Buffer Size", minimum=4, maximum=32, value=8) refresh_configs = gr.Button(value="Refresh Configurations") start_training_button = gr.Button(value="Train") stop_training_button = gr.Button(value="Stop") with gr.Column(): training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) + verbose_training = gr.Checkbox(label="Verbose Console Output") + training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8) with gr.Tab("Settings"): with gr.Row(): exec_inputs = []