forked from mrq/ai-voice-cloning
normalize validation batch size because i oom'd without it getting scaled
This commit is contained in:
parent
d7e75a51cf
commit
8494628f3c
|
@ -318,7 +318,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
|||
save_rate = int(save_rate * iterations / epochs)
|
||||
validation_rate = int(validation_rate * iterations / epochs)
|
||||
|
||||
validation_batch_size = batch_size
|
||||
validation_batch_size = int(batch_size / gradient_accumulation_size)
|
||||
|
||||
if iterations % save_rate != 0:
|
||||
adjustment = int(iterations / save_rate) * save_rate
|
||||
|
@ -333,6 +333,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
|||
with open(validation_path, 'r', encoding="utf-8") as f:
|
||||
validation_lines = len(f.readlines())
|
||||
|
||||
|
||||
if validation_lines < validation_batch_size:
|
||||
validation_batch_size = validation_lines
|
||||
messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")
|
||||
|
|
Loading…
Reference in New Issue
Block a user