1
0
Fork 0

Simplified generating training YAML, cleaned it up, training output is cleaned up and will "autoscroll" (only show the last 8 lines, refer to console for a full trace if needed)

master
mrq 2023-02-18 14:51:00 +07:00
parent 0dd5640a89
commit 843bfbfb96
2 changed files with 52 additions and 10 deletions

@ -449,9 +449,9 @@ def run_training(config_path):
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
buffer=[] buffer=[]
for line in iter(training_process.stdout.readline, ""): for line in iter(training_process.stdout.readline, ""):
buffer.append(line) buffer.append(f'[{datetime.now().isoformat()}] {line}')
print(line[:-1]) print(f"[Training] {line[:-1]}")
yield "".join(buffer) yield "".join(buffer[-8:])
training_process.stdout.close() training_process.stdout.close()
return_code = training_process.wait() return_code = training_process.wait()
@ -498,7 +498,7 @@ def setup_tortoise(restart=False):
print("TorToiSe initialized, ready for generation.") print("TorToiSe initialized, ready for generation.")
return tts return tts
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ): def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
settings = { settings = {
"batch_size": batch_size if batch_size else 128, "batch_size": batch_size if batch_size else 128,
"learning_rate": learning_rate if learning_rate else 1e-5, "learning_rate": learning_rate if learning_rate else 1e-5,
@ -510,7 +510,11 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
"validation_name": validation_name if validation_name else "finetune", "validation_name": validation_name if validation_name else "finetune",
"validation_path": validation_path if validation_path else "./training/finetune/train.txt", "validation_path": validation_path if validation_path else "./training/finetune/train.txt",
} }
outfile = f'./training/{settings["name"]}.yaml'
if not output_name:
output_name = f'{settings["name"]}.yaml'
outfile = f'./training/{output_name}'
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f: with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
yaml = f.read() yaml = f.read()
@ -724,9 +728,13 @@ def get_autoregressive_models(dir="./models/finetuned/"):
os.makedirs(dir, exist_ok=True) os.makedirs(dir, exist_ok=True)
return [get_model_path('autoregressive.pth')] + sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ]) return [get_model_path('autoregressive.pth')] + sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
def get_dataset_list(dir="./training/"):
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
def update_autoregressive_model(path_name): def get_training_list(dir="./training/"):
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
def update_autoregressive_model(path_name):
global tts global tts
if not tts: if not tts:
raise Exception("TTS is uninitialized or still initializing...") raise Exception("TTS is uninitialized or still initializing...")

@ -131,7 +131,7 @@ def get_training_configs():
return configs return configs
def update_training_configs(): def update_training_configs():
return gr.update(choices=get_training_configs()) return gr.update(choices=get_training_list())
history_headers = { history_headers = {
"Name": "", "Name": "",
@ -201,6 +201,24 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ): def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress ) return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ):
name = f"{voice}-finetune"
dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt"
validation_name = f"{voice}-val"
validation_path = f"./training/{voice}/train.txt"
with open(dataset_path, 'r', encoding="utf-8") as f:
lines = len(f.readlines())
if batch_size > lines:
print("Batch size is larger than your dataset, clamping...")
batch_size = lines
out_name = f"{voice}/train.yaml"
return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
def update_voices(): def update_voices():
return ( return (
gr.Dropdown.update(choices=get_voice_list()), gr.Dropdown.update(choices=get_voice_list()),
@ -333,6 +351,12 @@ def setup_gradio():
gr.Number(label="Print Frequency", value=50), gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50), gr.Number(label="Save Frequency", value=50),
] ]
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
training_settings = training_settings + [
dataset_list
]
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
"""
training_settings = training_settings + [ training_settings = training_settings + [
gr.Textbox(label="Training Name", placeholder="finetune"), gr.Textbox(label="Training Name", placeholder="finetune"),
gr.Textbox(label="Dataset Name", placeholder="finetune"), gr.Textbox(label="Dataset Name", placeholder="finetune"),
@ -340,13 +364,14 @@ def setup_gradio():
gr.Textbox(label="Validation Name", placeholder="finetune"), gr.Textbox(label="Validation Name", placeholder="finetune"),
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"), gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
] ]
"""
with gr.Column(): with gr.Column():
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
save_yaml_button = gr.Button(value="Save Training Configuration") save_yaml_button = gr.Button(value="Save Training Configuration")
with gr.Tab("Run Training"): with gr.Tab("Run Training"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs()) training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
refresh_configs = gr.Button(value="Refresh Configurations") refresh_configs = gr.Button(value="Refresh Configurations")
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop") stop_training_button = gr.Button(value="Stop")
@ -524,7 +549,11 @@ def setup_gradio():
outputs=input_settings outputs=input_settings
) )
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs) refresh_configs.click(
lambda: gr.update(choices=get_training_list()),
inputs=None,
outputs=training_configs
)
start_training_button.click(run_training, start_training_button.click(run_training,
inputs=training_configs, inputs=training_configs,
outputs=training_output #console_output outputs=training_output #console_output
@ -538,7 +567,12 @@ def setup_gradio():
inputs=dataset_settings, inputs=dataset_settings,
outputs=prepare_dataset_output #console_output outputs=prepare_dataset_output #console_output
) )
save_yaml_button.click(save_training_settings, refresh_dataset_list.click(
lambda: gr.update(choices=get_dataset_list()),
inputs=None,
outputs=dataset_list,
)
save_yaml_button.click(save_training_settings_proxy,
inputs=training_settings, inputs=training_settings,
outputs=save_yaml_output #console_output outputs=save_yaml_output #console_output
) )