forked from mrq/ai-voice-cloning
normalize validation batch size because i oom'd without it getting scaled
This commit is contained in:
parent
d7e75a51cf
commit
8494628f3c
|
@ -318,7 +318,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
||||||
save_rate = int(save_rate * iterations / epochs)
|
save_rate = int(save_rate * iterations / epochs)
|
||||||
validation_rate = int(validation_rate * iterations / epochs)
|
validation_rate = int(validation_rate * iterations / epochs)
|
||||||
|
|
||||||
validation_batch_size = batch_size
|
validation_batch_size = int(batch_size / gradient_accumulation_size)
|
||||||
|
|
||||||
if iterations % save_rate != 0:
|
if iterations % save_rate != 0:
|
||||||
adjustment = int(iterations / save_rate) * save_rate
|
adjustment = int(iterations / save_rate) * save_rate
|
||||||
|
@ -333,6 +333,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
||||||
with open(validation_path, 'r', encoding="utf-8") as f:
|
with open(validation_path, 'r', encoding="utf-8") as f:
|
||||||
validation_lines = len(f.readlines())
|
validation_lines = len(f.readlines())
|
||||||
|
|
||||||
|
|
||||||
if validation_lines < validation_batch_size:
|
if validation_lines < validation_batch_size:
|
||||||
validation_batch_size = validation_lines
|
validation_batch_size = validation_lines
|
||||||
messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")
|
messages.append(f"Batch size exceeds validation dataset size, clamping validation batch size to {validation_lines}")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user