|
|
|
@ -575,14 +575,17 @@ def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate
|
|
|
|
|
batch_size = lines
|
|
|
|
|
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
|
|
|
|
|
|
|
|
|
|
if batch_size / mega_batch_factor < 2:
|
|
|
|
|
mega_batch_factor = int(batch_size / 2)
|
|
|
|
|
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
|
|
|
|
|
|
|
|
|
|
if batch_size % lines != 0:
|
|
|
|
|
nearest_slice = int(lines / batch_size) + 1
|
|
|
|
|
batch_size = int(lines / nearest_slice)
|
|
|
|
|
messages.append(f"Batch size not neatly divisible by dataset size, adjusting batch size to: {batch_size} ({nearest_slice} steps per epoch)")
|
|
|
|
|
|
|
|
|
|
if batch_size == 1 and mega_batch_factor != 1:
|
|
|
|
|
mega_batch_factor = 1
|
|
|
|
|
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
|
|
|
|
|
elif batch_size / mega_batch_factor < 2:
|
|
|
|
|
mega_batch_factor = int(batch_size / 2)
|
|
|
|
|
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
|
|
|
|
|
|
|
|
|
|
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|