Simplified generating training YAML, cleaned it up, training output is cleaned up and will "autoscroll" (only show the last 8 lines, refer to console for a full trace if needed)

This commit is contained in:
mrq 2023-02-18 14:51:00 +00:00
parent 0dd5640a89
commit 843bfbfb96
2 changed files with 52 additions and 10 deletions

View File

@ -449,9 +449,9 @@ def run_training(config_path):
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
buffer=[]
for line in iter(training_process.stdout.readline, ""):
buffer.append(line)
print(line[:-1])
yield "".join(buffer)
buffer.append(f'[{datetime.now().isoformat()}] {line}')
print(f"[Training] {line[:-1]}")
yield "".join(buffer[-8:])
training_process.stdout.close()
return_code = training_process.wait()
@ -498,7 +498,7 @@ def setup_tortoise(restart=False):
print("TorToiSe initialized, ready for generation.")
return tts
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ):
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
settings = {
"batch_size": batch_size if batch_size else 128,
"learning_rate": learning_rate if learning_rate else 1e-5,
@ -510,7 +510,11 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
"validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
}
outfile = f'./training/{settings["name"]}.yaml'
if not output_name:
output_name = f'{settings["name"]}.yaml'
outfile = f'./training/{output_name}'
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
yaml = f.read()
@ -724,9 +728,13 @@ def get_autoregressive_models(dir="./models/finetuned/"):
os.makedirs(dir, exist_ok=True)
return [get_model_path('autoregressive.pth')] + sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
def get_dataset_list(dir="./training/"):
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
def get_training_list(dir="./training/"):
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
def update_autoregressive_model(path_name):
global tts
if not tts:
raise Exception("TTS is uninitialized or still initializing...")

View File

@ -131,7 +131,7 @@ def get_training_configs():
return configs
def update_training_configs():
return gr.update(choices=get_training_configs())
return gr.update(choices=get_training_list())
history_headers = {
"Name": "",
@ -201,6 +201,24 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ):
name = f"{voice}-finetune"
dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt"
validation_name = f"{voice}-val"
validation_path = f"./training/{voice}/train.txt"
with open(dataset_path, 'r', encoding="utf-8") as f:
lines = len(f.readlines())
if batch_size > lines:
print("Batch size is larger than your dataset, clamping...")
batch_size = lines
out_name = f"{voice}/train.yaml"
return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
def update_voices():
return (
gr.Dropdown.update(choices=get_voice_list()),
@ -333,6 +351,12 @@ def setup_gradio():
gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50),
]
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
training_settings = training_settings + [
dataset_list
]
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
"""
training_settings = training_settings + [
gr.Textbox(label="Training Name", placeholder="finetune"),
gr.Textbox(label="Dataset Name", placeholder="finetune"),
@ -340,13 +364,14 @@ def setup_gradio():
gr.Textbox(label="Validation Name", placeholder="finetune"),
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
]
"""
with gr.Column():
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
save_yaml_button = gr.Button(value="Save Training Configuration")
with gr.Tab("Run Training"):
with gr.Row():
with gr.Column():
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
refresh_configs = gr.Button(value="Refresh Configurations")
start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop")
@ -524,7 +549,11 @@ def setup_gradio():
outputs=input_settings
)
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
refresh_configs.click(
lambda: gr.update(choices=get_training_list()),
inputs=None,
outputs=training_configs
)
start_training_button.click(run_training,
inputs=training_configs,
outputs=training_output #console_output
@ -538,7 +567,12 @@ def setup_gradio():
inputs=dataset_settings,
outputs=prepare_dataset_output #console_output
)
save_yaml_button.click(save_training_settings,
refresh_dataset_list.click(
lambda: gr.update(choices=get_dataset_list()),
inputs=None,
outputs=dataset_list,
)
save_yaml_button.click(save_training_settings_proxy,
inputs=training_settings,
outputs=save_yaml_output #console_output
)