From 6d5e1e1a80d7c70a64bc3dee79653a2688859963 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 4 Mar 2023 04:41:56 +0000 Subject: [PATCH] fixed user inputted LR schedule not actually getting used (oops) --- src/utils.py | 4 ++-- src/webui.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/utils.py b/src/utils.py index f6f88a3..5c173a4 100755 --- a/src/utils.py +++ b/src/utils.py @@ -935,8 +935,8 @@ def calc_iterations( epochs, lines, batch_size ): return iterations EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] -def schedule_learning_rate( iterations ): - return [int(iterations * d) for d in EPOCH_SCHEDULE] +def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ): + return [int(iterations * d) for d in schedule] def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ): name = f"{voice}-finetune" diff --git a/src/webui.py b/src/webui.py index 396bde1..c94c59e 100755 --- a/src/webui.py +++ b/src/webui.py @@ -307,7 +307,10 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear if not learning_rate_schedule: learning_rate_schedule = EPOCH_SCHEDULE - learning_rate_schedule = schedule_learning_rate( iterations / epochs ) + elif isinstance(learning_rate_schedule,str): + learning_rate_schedule = json.loads(learning_rate_schedule) + + learning_rate_schedule = schedule_learning_rate( iterations / epochs, learning_rate_schedule ) messages.append(save_training_settings( iterations=iterations,