1
0

normalize validation batch size because i oom'd without it getting scaled

This commit is contained in:
mrq 2023-03-08 05:27:20 +00:00
parent d7e75a51cf
commit 8494628f3c

View File

@ -318,7 +318,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
save_rate = int(save_rate * iterations / epochs) save_rate = int(save_rate * iterations / epochs)
validation_rate = int(validation_rate * iterations / epochs) validation_rate = int(validation_rate * iterations / epochs)
validation_batch_size = batch_size validation_batch_size = int(batch_size / gradient_accumulation_size)
if iterations % save_rate != 0: if iterations % save_rate != 0:
adjustment = int(iterations / save_rate) * save_rate adjustment = int(iterations / save_rate) * save_rate
@ -333,6 +333,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
with open(validation_path, 'r', encoding="utf-8") as f: with open(validation_path, 'r', encoding="utf-8") as f:
validation_lines = len(f.readlines()) validation_lines = len(f.readlines())
if validation_lines < validation_batch_size: if validation_lines < validation_batch_size:
validation_batch_size = validation_lines validation_batch_size = validation_lines
messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}") messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")