@ -1347,7 +1347,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
messages
)
def save_training_settings ( iterations = None , learning_rate = None , text_ce_lr_weight = None , learning_rate_schedule = None , batch_size = None , gradient_accumulation_size = None , print_rate = None , save_rate = None , validation_rate = None , name = None , dataset_name = None , dataset_path = None , validation_name = None , validation_path = None , output_name= None , resume_path = None , half_p = None , bnb = None , workers = None , source_model = None ) :
def save_training_settings ( iterations = None , learning_rate = None , text_ce_lr_weight = None , learning_rate_schedule = None , batch_size = None , gradient_accumulation_size = None , print_rate = None , save_rate = None , validation_rate = None , name = None , dataset_name = None , dataset_path = None , validation_name = None , validation_path = None , validation_batch_size= None , output_name= None , resume_path = None , half_p = None , bnb = None , workers = None , source_model = None ) :
if not source_model :
source_model = f " ./models/tortoise/autoregressive { ' _half ' if half_p else ' ' } .pth "
@ -1365,6 +1365,8 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
" validation_name " : validation_name if validation_name else " finetune " ,
" validation_path " : validation_path if validation_path else " ./training/finetune/train.txt " ,
' validation_rate ' : validation_rate if validation_rate else iterations ,
" validation_batch_size " : validation_batch_size if validation_batch_size else batch_size ,
' validation_enabled ' : " true " ,
" text_ce_lr_weight " : text_ce_lr_weight if text_ce_lr_weight else 0.01 ,
@ -1382,6 +1384,11 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
else :
settings [ ' resume_state ' ] = f " # resume_state: ' ./training/ { name if name else ' finetune ' } /training_state/#.state ' "
# also disable validation when it would be meaningless: the validation set is the same file as the training set, or the validation file does not exist
if settings [ ' dataset_path ' ] == settings [ ' validation_path ' ] or not os . path . exists ( settings [ ' validation_path ' ] ) :
settings [ ' validation_enabled ' ] = ' false '
if half_p :
if not os . path . exists ( get_halfp_model_path ( ) ) :
convert_to_halfp ( )