fixed user inputted LR schedule not actually getting used (oops)

This commit is contained in:
mrq 2023-03-04 04:41:56 +00:00
parent 6d8c2dd459
commit 6d5e1e1a80
2 changed files with 6 additions and 3 deletions

View File

@ -935,8 +935,8 @@ def calc_iterations( epochs, lines, batch_size ):
return iterations return iterations
EPOCH_SCHEDULE = [ 9, 18, 25, 33 ] EPOCH_SCHEDULE = [ 9, 18, 25, 33 ]
def schedule_learning_rate( iterations ): def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
return [int(iterations * d) for d in EPOCH_SCHEDULE] return [int(iterations * d) for d in schedule]
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ): def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, mega_batch_factor, print_rate, save_rate, resume_path, half_p, bnb, source_model, voice ):
name = f"{voice}-finetune" name = f"{voice}-finetune"

View File

@ -307,7 +307,10 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
if not learning_rate_schedule: if not learning_rate_schedule:
learning_rate_schedule = EPOCH_SCHEDULE learning_rate_schedule = EPOCH_SCHEDULE
learning_rate_schedule = schedule_learning_rate( iterations / epochs ) elif isinstance(learning_rate_schedule,str):
learning_rate_schedule = json.loads(learning_rate_schedule)
learning_rate_schedule = schedule_learning_rate( iterations / epochs, learning_rate_schedule )
messages.append(save_training_settings( messages.append(save_training_settings(
iterations=iterations, iterations=iterations,