@@ -608,8 +608,12 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
        mega_batch_factor = int(batch_size / 2)
        messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")

    if batch_size % lines != 0:
        nearest_slice = int(lines / batch_size) + 1
        batch_size = int(lines / nearest_slice)
        messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")

    iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
    messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")

    if epochs < print_rate:
        print_rate = epochs
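The divisibility adjustment introduced in this hunk computes how many steps of the requested batch size it would take to cover the dataset, then shrinks the batch size so that many steps cover it almost exactly. The sketch below is not part of the patch; it just walks the arithmetic with invented numbers (249 dataset lines, requested batch size 64):

# Worked example of the batch-size adjustment above; the values are made up.
lines = 249       # entries in the training dataset
batch_size = 64   # requested batch size

if batch_size % lines != 0:
    nearest_slice = int(lines / batch_size) + 1   # int(3.89) + 1 = 4 steps per epoch
    batch_size = int(lines / nearest_slice)       # int(62.25) = 62

# With batch size 64 the leftover would be 249 - 3 * 64 = 57 lines;
# with batch size 62 it is 249 - 4 * 62 = 1 line.
print(batch_size, nearest_slice)  # 62 4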
@@ -623,8 +627,7 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
        resume_path = None
        messages.append("Resume path specified, but does not exist. Disabling...")

    learning_rate_schedule = schedule_learning_rate(iterations / epochs)  # faster learning schedule compared to just passing lines / batch_size due to truncating

    messages.append(f"Suggesting best learning rate schedule for iterations: {learning_rate_schedule}")
    messages.append(f"For {epochs} epochs with {lines} lines in batches of {batch_size}, iterating for {iterations} steps ({int(iterations / epochs)} steps per epoch)")

    return (
        batch_size,
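Neither calc_iterations nor schedule_learning_rate is shown in this hunk, so the snippet below only illustrates the truncation that the inline comment refers to. The assumption that calc_iterations truncates the per-epoch step count is mine, not taken from the diff; the numbers continue the invented example above:

# Why schedule_learning_rate() is fed iterations / epochs rather than lines / batch_size.
# Assumes calc_iterations() truncates the per-epoch step count (hypothetical; the helper
# is not shown in this diff).
epochs = 100
lines = 249
batch_size = 62

steps_per_epoch_raw = lines / batch_size        # 4.016... (fractional)
iterations = int(lines / batch_size) * epochs   # 4 * 100 = 400 steps actually run

print(steps_per_epoch_raw)   # 4.016129032258065
print(iterations / epochs)   # 4.0, matches the steps that really execute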
@@ -642,7 +645,7 @@ def save_training_settings( iterations=None, batch_size=None, learning_rate=None
        "iterations": iterations if iterations else 500,
        "batch_size": batch_size if batch_size else 64,
        "learning_rate": learning_rate if learning_rate else 1e-5,
        "gen_lr_steps": learning_rate_schedule if learning_rate_schedule else [200, 300, 400, 500],
        "gen_lr_steps": learning_rate_schedule if learning_rate_schedule else EPOCH_SCHEDULE,
        "mega_batch_factor": mega_batch_factor if mega_batch_factor else 4,
        "print_rate": print_rate if print_rate else 50,
        "save_rate": save_rate if save_rate else 50,
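This hunk swaps the hard-coded default of [200, 300, 400, 500] for the EPOCH_SCHEDULE constant, whose value is defined elsewhere and not shown here. Every default in the block uses the `value if value else fallback` idiom, so any falsy argument, not just None, falls back. A tiny demonstration with invented numbers, separate from the patch:

# Demonstration of the "value if value else fallback" pattern used for the defaults.
def pick(value, fallback):
    return value if value else fallback

print(pick(None, 500))  # 500: unset argument falls back
print(pick(250, 500))   # 250: explicit value wins
print(pick(0, 500))     # 500: note that an explicit 0 also falls back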
@@ -656,6 +659,7 @@ def save_training_settings( iterations=None, batch_size=None, learning_rate=None
        'pretrain_model_gpt': "pretrain_model_gpt: './models/tortoise/autoregressive.pth'" if not resume_path else "# pretrain_model_gpt: './models/tortoise/autoregressive.pth'"
    }

    if not output_name:
        output_name = f'{settings["name"]}.yaml'
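The pretrain_model_gpt value looks like a line that gets pasted verbatim into the YAML training configuration, commented out when resuming from an existing state; that substitution step is not shown in this hunk, so the sketch below only exercises the two branches and the output_name default. The resume path is a made-up example value:

# Sketch of the two pretrain_model_gpt branches and the output_name default.
def pretrain_model_gpt_line(resume_path):
    return (
        "pretrain_model_gpt: './models/tortoise/autoregressive.pth'"
        if not resume_path
        else "# pretrain_model_gpt: './models/tortoise/autoregressive.pth'"
    )

print(pretrain_model_gpt_line(None))                                # fresh run: base model is loaded
print(pretrain_model_gpt_line("./training/example/resume.state"))   # resuming: line is commented out

settings = {"name": "example"}
output_name = None
if not output_name:
    output_name = f'{settings["name"]}.yaml'
print(output_name)  # example.yaml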