diff --git a/models/.template.yaml b/models/.template.yaml index 28fbd2c..258ee89 100755 --- a/models/.template.yaml +++ b/models/.template.yaml @@ -29,7 +29,7 @@ datasets: val: # I really do not care about validation right now name: ${validation_name} n_workers: ${workers} - batch_size: ${batch_size} + batch_size: ${validation_batch_size} mode: paired_voice_audio path: ${validation_path} fetcher_mode: ['lj'] @@ -131,7 +131,7 @@ train: lr_gamma: 0.5 eval: - pure: True + pure: ${validation_enabled} output_state: gen logger: diff --git a/src/utils.py b/src/utils.py index 7cfc505..ce318a5 100755 --- a/src/utils.py +++ b/src/utils.py @@ -1347,7 +1347,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni messages ) -def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ): +def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, validation_batch_size=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ): if not source_model: source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth" @@ -1365,6 +1365,8 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig "validation_name": validation_name if validation_name else "finetune", "validation_path": validation_path if validation_path else "./training/finetune/train.txt", 'validation_rate': validation_rate if validation_rate else iterations, + "validation_batch_size": validation_batch_size if validation_batch_size else batch_size, + 'validation_enabled': "true", "text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01, @@ -1382,6 +1384,11 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig else: settings['resume_state'] = f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'" + # also disable validation if it doesn't make sense to do it + if settings['dataset_path'] == settings['validation_path'] or not os.path.exists(settings['validation_path']): + settings['validation_enabled'] = 'false' + + if half_p: if not os.path.exists(get_halfp_model_path()): convert_to_halfp() diff --git a/src/webui.py b/src/webui.py index 9e85230..1ded6c6 100755 --- a/src/webui.py +++ b/src/webui.py @@ -318,6 +318,8 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear save_rate = int(save_rate * iterations / epochs) validation_rate = int(validation_rate * iterations / epochs) + validation_batch_size = batch_size + if iterations % save_rate != 0: adjustment = int(iterations / save_rate) * save_rate messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}") @@ -326,6 +328,14 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear if not os.path.exists(validation_path): validation_rate = iterations validation_path = dataset_path + messages.append("Validation not found, disabling validation...") + else: + with open(validation_path, 'r', encoding="utf-8") as f: + validation_lines = len(f.readlines()) + + if validation_lines < validation_batch_size: + validation_batch_size = validation_lines + messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}") if not learning_rate_schedule: learning_rate_schedule = EPOCH_SCHEDULE @@ -349,6 +359,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear dataset_path=dataset_path, validation_name=validation_name, validation_path=validation_path, + validation_batch_size=validation_batch_size, output_name=f"{voice}/train.yaml", resume_path=resume_path, half_p=half_p,