forked from mrq/ai-voice-cloning
doing something completely unrelated had me realize it's 1000x easier to just base things in terms of epochs, and calculate iteratsions from there
This commit is contained in:
parent
ec76676b16
commit
4694d622f4
57
src/utils.py
57
src/utils.py
|
@ -580,6 +580,63 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
||||||
|
|
||||||
return voice
|
return voice
|
||||||
|
|
||||||
|
def calc_iterations( epochs, lines, batch_size ):
|
||||||
|
iterations = int(epochs * lines / float(batch_size))
|
||||||
|
return iterations
|
||||||
|
|
||||||
|
def schedule_learning_rate( iterations ):
|
||||||
|
schedule = [ 9, 18, 25, 33 ]
|
||||||
|
return [int(iterations * d) for d in schedule]
|
||||||
|
|
||||||
|
def optimize_training_settings( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
|
||||||
|
name = f"{voice}-finetune"
|
||||||
|
dataset_name = f"{voice}-train"
|
||||||
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
validation_name = f"{voice}-val"
|
||||||
|
validation_path = f"./training/{voice}/train.txt"
|
||||||
|
|
||||||
|
with open(dataset_path, 'r', encoding="utf-8") as f:
|
||||||
|
lines = len(f.readlines())
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
if batch_size > lines:
|
||||||
|
batch_size = lines
|
||||||
|
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
|
||||||
|
|
||||||
|
if batch_size / mega_batch_factor < 2:
|
||||||
|
mega_batch_factor = int(batch_size / 2)
|
||||||
|
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
|
||||||
|
|
||||||
|
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
||||||
|
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
|
||||||
|
|
||||||
|
if iterations < print_rate:
|
||||||
|
print_rate = iterations
|
||||||
|
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
|
||||||
|
|
||||||
|
if iterations < save_rate:
|
||||||
|
save_rate = iterations
|
||||||
|
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
||||||
|
|
||||||
|
if resume_path and not os.path.exists(resume_path):
|
||||||
|
resume_path = None
|
||||||
|
messages.append("Resume path specified, but does not exist. Disabling...")
|
||||||
|
|
||||||
|
learning_rate_schedule = schedule_learning_rate( iterations / epochs ) # faster learning schedule compared to just passing lines / batch_size due to truncating
|
||||||
|
messages.append(f"Suggesting best learning rate schedule for iterations: {learning_rate_schedule}")
|
||||||
|
|
||||||
|
return (
|
||||||
|
batch_size,
|
||||||
|
learning_rate,
|
||||||
|
learning_rate_schedule,
|
||||||
|
mega_batch_factor,
|
||||||
|
print_rate,
|
||||||
|
save_rate,
|
||||||
|
resume_path,
|
||||||
|
messages
|
||||||
|
)
|
||||||
|
|
||||||
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
|
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
|
||||||
settings = {
|
settings = {
|
||||||
"iterations": iterations if iterations else 500,
|
"iterations": iterations if iterations else 500,
|
||||||
|
|
54
src/webui.py
54
src/webui.py
|
@ -180,7 +180,21 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
||||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
||||||
|
|
||||||
def save_training_settings_proxy( iterations, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
|
def optimize_training_settings_proxy( *args, **kwargs ):
|
||||||
|
tup = optimize_training_settings(*args, **kwargs)
|
||||||
|
|
||||||
|
return (
|
||||||
|
gr.update(value=tup[0]),
|
||||||
|
gr.update(value=tup[1]),
|
||||||
|
gr.update(value=tup[2]),
|
||||||
|
gr.update(value=tup[3]),
|
||||||
|
gr.update(value=tup[4]),
|
||||||
|
gr.update(value=tup[5]),
|
||||||
|
gr.update(value=tup[6]),
|
||||||
|
"\n".join(tup[7])
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_training_settings_proxy( epochs, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -192,28 +206,11 @@ def save_training_settings_proxy( iterations, batch_size, learning_rate, learnin
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
if batch_size > lines:
|
iterations = calc_iterations(epochs=epochs, lines=lines, batch_size=batch_size)
|
||||||
batch_size = lines
|
messages.append(f"For {epochs} epochs with {lines} lines, iterating for {iterations} steps")
|
||||||
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
|
|
||||||
|
|
||||||
|
|
||||||
if batch_size / mega_batch_factor < 2:
|
messages.append(save_training_settings(
|
||||||
mega_batch_factor = int(batch_size / 2)
|
iterations=iterations,
|
||||||
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
|
|
||||||
|
|
||||||
if iterations < print_rate:
|
|
||||||
print_rate = iterations
|
|
||||||
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
|
|
||||||
|
|
||||||
if iterations < save_rate:
|
|
||||||
save_rate = iterations
|
|
||||||
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
|
||||||
|
|
||||||
if resume_path and not os.path.exists(resume_path):
|
|
||||||
messages.append("Resume path specified, but does not exist. Disabling...")
|
|
||||||
resume_path = None
|
|
||||||
|
|
||||||
messages.append(save_training_settings(iterations,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
learning_rate_schedule=learning_rate_schedule,
|
learning_rate_schedule=learning_rate_schedule,
|
||||||
|
@ -355,19 +352,17 @@ def setup_gradio():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_settings = [
|
training_settings = [
|
||||||
gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
|
gr.Slider(label="Epochs", minimum=0, maximum=500, value=10),
|
||||||
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
||||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||||
gr.Textbox(label="Learning Rate Schedule", placeholder="[ 200, 300, 400, 500 ]"),
|
gr.Textbox(label="Learning Rate Schedule", placeholder="[ 100, 200, 300, 400 ]"),
|
||||||
gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1),
|
gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4, step=1),
|
||||||
gr.Number(label="Print Frequency", value=50),
|
gr.Number(label="Print Frequency", value=50),
|
||||||
gr.Number(label="Save Frequency", value=50),
|
gr.Number(label="Save Frequency", value=50),
|
||||||
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
||||||
]
|
]
|
||||||
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
||||||
training_settings = training_settings + [
|
training_settings = training_settings + [ dataset_list ]
|
||||||
dataset_list
|
|
||||||
]
|
|
||||||
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
||||||
"""
|
"""
|
||||||
training_settings = training_settings + [
|
training_settings = training_settings + [
|
||||||
|
@ -380,6 +375,7 @@ def setup_gradio():
|
||||||
"""
|
"""
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
|
optimize_yaml_button = gr.Button(value="Validate Training Configuration")
|
||||||
save_yaml_button = gr.Button(value="Save Training Configuration")
|
save_yaml_button = gr.Button(value="Save Training Configuration")
|
||||||
with gr.Tab("Run Training"):
|
with gr.Tab("Run Training"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -591,6 +587,10 @@ def setup_gradio():
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=dataset_list,
|
outputs=dataset_list,
|
||||||
)
|
)
|
||||||
|
optimize_yaml_button.click(optimize_training_settings_proxy,
|
||||||
|
inputs=training_settings,
|
||||||
|
outputs=training_settings[1:8] + [save_yaml_output] #console_output
|
||||||
|
)
|
||||||
save_yaml_button.click(save_training_settings_proxy,
|
save_yaml_button.click(save_training_settings_proxy,
|
||||||
inputs=training_settings,
|
inputs=training_settings,
|
||||||
outputs=save_yaml_output #console_output
|
outputs=save_yaml_output #console_output
|
||||||
|
|
Loading…
Reference in New Issue
Block a user