|
|
|
@ -205,7 +205,8 @@ def optimize_training_settings_proxy( *args, **kwargs ):
|
|
|
|
|
gr.update(value=tup[5]),
|
|
|
|
|
gr.update(value=tup[6]),
|
|
|
|
|
gr.update(value=tup[7]),
|
|
|
|
|
"\n".join(tup[8])
|
|
|
|
|
gr.update(value=tup[8]),
|
|
|
|
|
"\n".join(tup[9])
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def import_training_settings_proxy( voice ):
|
|
|
|
@ -247,11 +248,15 @@ def import_training_settings_proxy( voice ):
|
|
|
|
|
|
|
|
|
|
print_rate = int(config['logger']['print_freq'] / steps_per_iteration)
|
|
|
|
|
save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration)
|
|
|
|
|
validation_rate = int(config['train']['val_freq'] / steps_per_iteration)
|
|
|
|
|
|
|
|
|
|
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
|
|
|
|
|
half_p = config['fp16']
|
|
|
|
|
bnb = True
|
|
|
|
|
|
|
|
|
|
statedir = f'{outdir}/training_state/'
|
|
|
|
|
resumes = []
|
|
|
|
|
resume_path = None
|
|
|
|
|
source_model = None
|
|
|
|
|
source_model = get_halfp_model_path() if half_p else get_model_path('autoregressive.pth')
|
|
|
|
|
|
|
|
|
|
if "pretrain_model_gpt" in config['path']:
|
|
|
|
|
source_model = config['path']['pretrain_model_gpt']
|
|
|
|
@ -267,8 +272,6 @@ def import_training_settings_proxy( voice ):
|
|
|
|
|
messages.append(f"Latest resume found: {resume_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
half_p = config['fp16']
|
|
|
|
|
bnb = True
|
|
|
|
|
|
|
|
|
|
if "ext" in config and "bitsandbytes" in config["ext"]:
|
|
|
|
|
bnb = config["ext"]["bitsandbytes"]
|
|
|
|
@ -286,6 +289,7 @@ def import_training_settings_proxy( voice ):
|
|
|
|
|
gradient_accumulation_size,
|
|
|
|
|
print_rate,
|
|
|
|
|
save_rate,
|
|
|
|
|
validation_rate,
|
|
|
|
|
resume_path,
|
|
|
|
|
half_p,
|
|
|
|
|
bnb,
|
|
|
|
@ -295,7 +299,7 @@ def import_training_settings_proxy( voice ):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
|
|
|
|
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
|
|
|
|
name = f"{voice}-finetune"
|
|
|
|
|
dataset_name = f"{voice}-train"
|
|
|
|
|
dataset_path = f"./training/{voice}/train.txt"
|
|
|
|
@ -312,7 +316,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
|
|
|
|
|
|
|
|
|
print_rate = int(print_rate * iterations / epochs)
|
|
|
|
|
save_rate = int(save_rate * iterations / epochs)
|
|
|
|
|
validation_rate = save_rate
|
|
|
|
|
validation_rate = int(validation_rate * iterations / epochs)
|
|
|
|
|
|
|
|
|
|
if iterations % save_rate != 0:
|
|
|
|
|
adjustment = int(iterations / save_rate) * save_rate
|
|
|
|
@ -497,7 +501,9 @@ def setup_gradio():
|
|
|
|
|
gr.Textbox(label="Language", value="en"),
|
|
|
|
|
gr.Checkbox(label="Skip Already Transcribed", value=False)
|
|
|
|
|
]
|
|
|
|
|
prepare_dataset_button = gr.Button(value="Prepare")
|
|
|
|
|
transcribe_button = gr.Button(value="Transcribe")
|
|
|
|
|
validation_text_cull_size = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0)
|
|
|
|
|
prepare_validation_button = gr.Button(value="Prepare Validation")
|
|
|
|
|
with gr.Column():
|
|
|
|
|
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
|
|
|
|
with gr.Tab("Generate Configuration"):
|
|
|
|
@ -524,6 +530,7 @@ def setup_gradio():
|
|
|
|
|
training_settings = training_settings + [
|
|
|
|
|
gr.Number(label="Print Frequency (in epochs)", value=5, precision=0),
|
|
|
|
|
gr.Number(label="Save Frequency (in epochs)", value=5, precision=0),
|
|
|
|
|
gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0),
|
|
|
|
|
]
|
|
|
|
|
training_settings = training_settings + [
|
|
|
|
|
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
|
|
|
@ -823,11 +830,19 @@ def setup_gradio():
|
|
|
|
|
],
|
|
|
|
|
outputs=training_output #console_output
|
|
|
|
|
)
|
|
|
|
|
prepare_dataset_button.click(
|
|
|
|
|
transcribe_button.click(
|
|
|
|
|
prepare_dataset_proxy,
|
|
|
|
|
inputs=dataset_settings,
|
|
|
|
|
outputs=prepare_dataset_output #console_output
|
|
|
|
|
)
|
|
|
|
|
prepare_validation_button.click(
|
|
|
|
|
prepare_validation_dataset,
|
|
|
|
|
inputs=[
|
|
|
|
|
dataset_settings[0],
|
|
|
|
|
validation_text_cull_size,
|
|
|
|
|
],
|
|
|
|
|
outputs=prepare_dataset_output #console_output
|
|
|
|
|
)
|
|
|
|
|
refresh_dataset_list.click(
|
|
|
|
|
lambda: gr.update(choices=get_dataset_list()),
|
|
|
|
|
inputs=None,
|
|
|
|
@ -835,11 +850,11 @@ def setup_gradio():
|
|
|
|
|
)
|
|
|
|
|
optimize_yaml_button.click(optimize_training_settings_proxy,
|
|
|
|
|
inputs=training_settings,
|
|
|
|
|
outputs=training_settings[1:9] + [save_yaml_output] #console_output
|
|
|
|
|
outputs=training_settings[1:10] + [save_yaml_output] #console_output
|
|
|
|
|
)
|
|
|
|
|
import_dataset_button.click(import_training_settings_proxy,
|
|
|
|
|
inputs=dataset_list_dropdown,
|
|
|
|
|
outputs=training_settings[:13] + [save_yaml_output] #console_output
|
|
|
|
|
outputs=training_settings[:14] + [save_yaml_output] #console_output
|
|
|
|
|
)
|
|
|
|
|
save_yaml_button.click(save_training_settings_proxy,
|
|
|
|
|
inputs=training_settings,
|
|
|
|
|