doing something completely unrelated had me realize it's 1000x easier to just base things in terms of epochs, and calculate iteratsions from there

This commit is contained in:
mrq 2023-02-19 20:22:03 +00:00
parent ec76676b16
commit 4694d622f4
2 changed files with 84 additions and 27 deletions

View File

@ -580,6 +580,63 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
return voice return voice
def calc_iterations( epochs, lines, batch_size ):
iterations = int(epochs * lines / float(batch_size))
return iterations
def schedule_learning_rate( iterations ):
schedule = [ 9, 18, 25, 33 ]
return [int(iterations * d) for d in schedule]
def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
name = f"{voice}-finetune"
dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt"
validation_name = f"{voice}-val"
validation_path = f"./training/{voice}/train.txt"
with open(dataset_path, 'r', encoding="utf-8") as f:
lines = len(f.readlines())
messages = []
if batch_size > lines:
batch_size = lines
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
if batch_size / mega_batch_factor < 2:
mega_batch_factor = int(batch_size / 2)
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
if iterations < print_rate:
print_rate = iterations
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
if iterations < save_rate:
save_rate = iterations
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
if resume_path and not os.path.exists(resume_path):
resume_path = None
messages.append("Resume path specified, but does not exist. Disabling...")
learning_rate_schedule = schedule_learning_rate( iterations / epochs ) # faster learning schedule compared to just passing lines / batch_size due to truncating
messages.append(f"Suggesting best learning rate schedule for iterations: {learning_rate_schedule}")
return (
batch_size,
learning_rate,
learning_rate_schedule,
mega_batch_factor,
print_rate,
save_rate,
resume_path,
messages
)
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ): def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
settings = { settings = {
"iterations": iterations if iterations else 500, "iterations": iterations if iterations else 500,

View File

@ -180,7 +180,21 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ): def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress ) return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
def save_training_settings_proxy( iterations, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ): def optimize_training_settings_proxy( *args, **kwargs ):
tup = optimize_training_settings(*args, **kwargs)
return (
gr.update(value=tup[0]),
gr.update(value=tup[1]),
gr.update(value=tup[2]),
gr.update(value=tup[3]),
gr.update(value=tup[4]),
gr.update(value=tup[5]),
gr.update(value=tup[6]),
"\n".join(tup[7])
)
def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
name = f"{voice}-finetune" name = f"{voice}-finetune"
dataset_name = f"{voice}-train" dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt" dataset_path = f"./training/{voice}/train.txt"
@ -192,28 +206,11 @@ def save_training_settings_proxy( iterations, batch_size, learning_rate, learnin
messages = [] messages = []
if batch_size > lines: iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
batch_size = lines messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
if batch_size / mega_batch_factor < 2: messages.append(save_training_settings(
mega_batch_factor = int(batch_size / 2) iterations=iterations,
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
if iterations < print_rate:
print_rate = iterations
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
if iterations < save_rate:
save_rate = iterations
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
if resume_path and not os.path.exists(resume_path):
messages.append("Resume path specified, but does not exist. Disabling...")
resume_path = None
messages.append(save_training_settings(iterations,
batch_size=batch_size, batch_size=batch_size,
learning_rate=learning_rate, learning_rate=learning_rate,
learning_rate_schedule=learning_rate_schedule, learning_rate_schedule=learning_rate_schedule,
@ -355,19 +352,17 @@ def setup_gradio():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
training_settings = [ training_settings = [
gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500), gr.Slider(label="Epochs", minimum=0, maximum=500, value=10),
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64), gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
gr.Textbox(label="Learning Rate Schedule", placeholder="[ 200, 300, 400, 500 ]"), gr.Textbox(label="Learning Rate Schedule", placeholder="[ 100, 200, 300, 400 ]"),
gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1), gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1),
gr.Number(label="Print Frequency", value=50), gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50), gr.Number(label="Save Frequency", value=50),
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"), gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
] ]
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" ) dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
training_settings = training_settings + [ training_settings = training_settings + [ dataset_list ]
dataset_list
]
refresh_dataset_list = gr.Button(value="Refresh Dataset List") refresh_dataset_list = gr.Button(value="Refresh Dataset List")
""" """
training_settings = training_settings + [ training_settings = training_settings + [
@ -380,6 +375,7 @@ def setup_gradio():
""" """
with gr.Column(): with gr.Column():
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
optimize_yaml_button = gr.Button(value="Validate Training Configuration")
save_yaml_button = gr.Button(value="Save Training Configuration") save_yaml_button = gr.Button(value="Save Training Configuration")
with gr.Tab("Run Training"): with gr.Tab("Run Training"):
with gr.Row(): with gr.Row():
@ -591,6 +587,10 @@ def setup_gradio():
inputs=None, inputs=None,
outputs=dataset_list, outputs=dataset_list,
) )
optimize_yaml_button.click(optimize_training_settings_proxy,
inputs=training_settings,
outputs=training_settings[1:8] + [save_yaml_output] #console_output
)
save_yaml_button.click(save_training_settings_proxy, save_yaml_button.click(save_training_settings_proxy,
inputs=training_settings, inputs=training_settings,
outputs=save_yaml_output #console_output outputs=save_yaml_output #console_output