set validation to save rate and validation file if exists (need to test later)

This commit is contained in:
mrq 2023-03-07 20:38:31 +00:00
parent fe8bf7a9d1
commit e862169e7f
3 changed files with 17 additions and 8 deletions

View File

@ -122,7 +122,7 @@ train:
niter: ${iterations} niter: ${iterations}
warmup_iter: -1 warmup_iter: -1
mega_batch_factor: ${gradient_accumulation_size} mega_batch_factor: ${gradient_accumulation_size}
val_freq: ${iterations} val_freq: ${validation_rate}
ema_enabled: false # I really don't think EMA matters ema_enabled: false # I really don't think EMA matters

View File

@ -1228,10 +1228,7 @@ def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ): def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
name = f"{voice}-finetune" name = f"{voice}-finetune"
dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt" dataset_path = f"./training/{voice}/train.txt"
validation_name = f"{voice}-val"
validation_path = f"./training/{voice}/train.txt"
with open(dataset_path, 'r', encoding="utf-8") as f: with open(dataset_path, 'r', encoding="utf-8") as f:
lines = len(f.readlines()) lines = len(f.readlines())
@ -1303,7 +1300,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
messages messages
) )
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ): def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
if not source_model: if not source_model:
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth" source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
@ -1320,6 +1317,7 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
"dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt", "dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
"validation_name": validation_name if validation_name else "finetune", "validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt", "validation_path": validation_path if validation_path else "./training/finetune/train.txt",
'validation_rate': validation_rate if validation_rate else iterations,
"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01, "text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,

View File

@ -300,7 +300,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
dataset_name = f"{voice}-train" dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt" dataset_path = f"./training/{voice}/train.txt"
validation_name = f"{voice}-val" validation_name = f"{voice}-val"
validation_path = f"./training/{voice}/train.txt" validation_path = f"./training/{voice}/validation.txt"
with open(dataset_path, 'r', encoding="utf-8") as f: with open(dataset_path, 'r', encoding="utf-8") as f:
lines = len(f.readlines()) lines = len(f.readlines())
@ -312,6 +312,16 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
print_rate = int(print_rate * iterations / epochs) print_rate = int(print_rate * iterations / epochs)
save_rate = int(save_rate * iterations / epochs) save_rate = int(save_rate * iterations / epochs)
validation_rate = save_rate
if iterations % save_rate != 0:
adjustment = int(iterations / save_rate) * save_rate
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}")
iterations = adjustment
if not os.path.exists(validation_path):
validation_rate = iterations
validation_path = dataset_path
if not learning_rate_schedule: if not learning_rate_schedule:
learning_rate_schedule = EPOCH_SCHEDULE learning_rate_schedule = EPOCH_SCHEDULE
@ -329,6 +339,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
gradient_accumulation_size=gradient_accumulation_size, gradient_accumulation_size=gradient_accumulation_size,
print_rate=print_rate, print_rate=print_rate,
save_rate=save_rate, save_rate=save_rate,
validation_rate=validation_rate,
name=name, name=name,
dataset_name=dataset_name, dataset_name=dataset_name,
dataset_path=dataset_path, dataset_path=dataset_path,
@ -559,7 +570,7 @@ def setup_gradio():
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
with gr.Row(): with gr.Row():
training_keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_gpu_count = gr.Number(label="GPUs", value=get_device_count()) training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
with gr.Row(): with gr.Row():
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
@ -777,7 +788,7 @@ def setup_gradio():
training_configs, training_configs,
verbose_training, verbose_training,
training_gpu_count, training_gpu_count,
training_keep_x_past_checkpoints, training_keep_x_past_datasets,
], ],
outputs=[ outputs=[
training_output, training_output,