forked from mrq/ai-voice-cloning
Simplified generating training YAML, cleaned it up, training output is cleaned up and will "autoscroll" (only show the last 8 lines, refer to console for a full trace if needed)
This commit is contained in:
parent
0dd5640a89
commit
843bfbfb96
20
src/utils.py
20
src/utils.py
|
@ -449,9 +449,9 @@ def run_training(config_path):
|
|||
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
buffer=[]
|
||||
for line in iter(training_process.stdout.readline, ""):
|
||||
buffer.append(line)
|
||||
print(line[:-1])
|
||||
yield "".join(buffer)
|
||||
buffer.append(f'[{datetime.now().isoformat()}] {line}')
|
||||
print(f"[Training] {line[:-1]}")
|
||||
yield "".join(buffer[-8:])
|
||||
|
||||
training_process.stdout.close()
|
||||
return_code = training_process.wait()
|
||||
|
@ -498,7 +498,7 @@ def setup_tortoise(restart=False):
|
|||
print("TorToiSe initialized, ready for generation.")
|
||||
return tts
|
||||
|
||||
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ):
|
||||
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
|
||||
settings = {
|
||||
"batch_size": batch_size if batch_size else 128,
|
||||
"learning_rate": learning_rate if learning_rate else 1e-5,
|
||||
|
@ -510,7 +510,11 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
|
|||
"validation_name": validation_name if validation_name else "finetune",
|
||||
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
||||
}
|
||||
outfile = f'./training/{settings["name"]}.yaml'
|
||||
|
||||
if not output_name:
|
||||
output_name = f'{settings["name"]}.yaml'
|
||||
|
||||
outfile = f'./training/{output_name}'
|
||||
|
||||
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
|
||||
yaml = f.read()
|
||||
|
@ -724,9 +728,13 @@ def get_autoregressive_models(dir="./models/finetuned/"):
|
|||
os.makedirs(dir, exist_ok=True)
|
||||
return [get_model_path('autoregressive.pth')] + sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
|
||||
|
||||
def get_dataset_list(dir="./training/"):
|
||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
||||
|
||||
def get_training_list(dir="./training/"):
|
||||
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
|
||||
def update_autoregressive_model(path_name):
|
||||
|
||||
global tts
|
||||
if not tts:
|
||||
raise Exception("TTS is uninitialized or still initializing...")
|
||||
|
|
42
src/webui.py
42
src/webui.py
|
@ -131,7 +131,7 @@ def get_training_configs():
|
|||
return configs
|
||||
|
||||
def update_training_configs():
|
||||
return gr.update(choices=get_training_configs())
|
||||
return gr.update(choices=get_training_list())
|
||||
|
||||
history_headers = {
|
||||
"Name": "",
|
||||
|
@ -201,6 +201,24 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
|||
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
||||
|
||||
def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
validation_name = f"{voice}-val"
|
||||
validation_path = f"./training/{voice}/train.txt"
|
||||
|
||||
with open(dataset_path, 'r', encoding="utf-8") as f:
|
||||
lines = len(f.readlines())
|
||||
|
||||
if batch_size > lines:
|
||||
print("Batch size is larger than your dataset, clamping...")
|
||||
batch_size = lines
|
||||
|
||||
out_name = f"{voice}/train.yaml"
|
||||
|
||||
return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
|
||||
|
||||
def update_voices():
|
||||
return (
|
||||
gr.Dropdown.update(choices=get_voice_list()),
|
||||
|
@ -333,6 +351,12 @@ def setup_gradio():
|
|||
gr.Number(label="Print Frequency", value=50),
|
||||
gr.Number(label="Save Frequency", value=50),
|
||||
]
|
||||
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
||||
training_settings = training_settings + [
|
||||
dataset_list
|
||||
]
|
||||
refresh_dataset_list = gr.Button(value="Refresh Dataset List")
|
||||
"""
|
||||
training_settings = training_settings + [
|
||||
gr.Textbox(label="Training Name", placeholder="finetune"),
|
||||
gr.Textbox(label="Dataset Name", placeholder="finetune"),
|
||||
|
@ -340,13 +364,14 @@ def setup_gradio():
|
|||
gr.Textbox(label="Validation Name", placeholder="finetune"),
|
||||
gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
|
||||
]
|
||||
"""
|
||||
with gr.Column():
|
||||
save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
save_yaml_button = gr.Button(value="Save Training Configuration")
|
||||
with gr.Tab("Run Training"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
|
||||
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
|
||||
refresh_configs = gr.Button(value="Refresh Configurations")
|
||||
start_training_button = gr.Button(value="Train")
|
||||
stop_training_button = gr.Button(value="Stop")
|
||||
|
@ -524,7 +549,11 @@ def setup_gradio():
|
|||
outputs=input_settings
|
||||
)
|
||||
|
||||
refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
|
||||
refresh_configs.click(
|
||||
lambda: gr.update(choices=get_training_list()),
|
||||
inputs=None,
|
||||
outputs=training_configs
|
||||
)
|
||||
start_training_button.click(run_training,
|
||||
inputs=training_configs,
|
||||
outputs=training_output #console_output
|
||||
|
@ -538,7 +567,12 @@ def setup_gradio():
|
|||
inputs=dataset_settings,
|
||||
outputs=prepare_dataset_output #console_output
|
||||
)
|
||||
save_yaml_button.click(save_training_settings,
|
||||
refresh_dataset_list.click(
|
||||
lambda: gr.update(choices=get_dataset_list()),
|
||||
inputs=None,
|
||||
outputs=dataset_list,
|
||||
)
|
||||
save_yaml_button.click(save_training_settings_proxy,
|
||||
inputs=training_settings,
|
||||
outputs=save_yaml_output #console_output
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user