forked from mrq/ai-voice-cloning
Simplified training-YAML generation and cleaned it up; training output is also cleaned up and will "autoscroll" (only the last 8 lines are shown; refer to the console for a full trace if needed)
parent 0dd5640a89
commit 843bfbfb96

src/utils.py | 20
@@ -449,9 +449,9 @@ def run_training(config_path):
     training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
     buffer=[]
     for line in iter(training_process.stdout.readline, ""):
-        buffer.append(line)
-        print(line[:-1])
-        yield "".join(buffer)
+        buffer.append(f'[{datetime.now().isoformat()}] {line}')
+        print(f"[Training] {line[:-1]}")
+        yield "".join(buffer[-8:])
 
     training_process.stdout.close()
     return_code = training_process.wait()
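The "autoscroll" from the commit message comes from this hunk: every line is still timestamped and appended to the buffer, and the full trace still reaches the console via print(), but the Gradio textbox only ever receives the last 8 entries. A minimal standalone sketch of the pattern, with stream_tail as an illustrative name rather than anything in the repo:

import subprocess
from datetime import datetime

def stream_tail(cmd, tail=8):
    # Stream a subprocess's combined stdout/stderr, yielding only the
    # rolling tail for the UI; the console still gets every line.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, universal_newlines=True)
    buffer = []
    for line in iter(proc.stdout.readline, ""):
        buffer.append(f'[{datetime.now().isoformat()}] {line}')
        print(line[:-1])               # full trace, trailing newline stripped
        yield "".join(buffer[-tail:])  # UI sees at most `tail` lines
    proc.stdout.close()
    proc.wait()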
@@ -498,7 +498,7 @@ def setup_tortoise(restart=False):
     print("TorToiSe initialized, ready for generation.")
     return tts
 
-def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None ):
+def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
     settings = {
         "batch_size": batch_size if batch_size else 128,
         "learning_rate": learning_rate if learning_rate else 1e-5,
@@ -510,7 +510,11 @@ def save_training_settings( batch_size=None, learning_rate=None, print_rate=None
         "validation_name": validation_name if validation_name else "finetune",
         "validation_path": validation_path if validation_path else "./training/finetune/train.txt",
     }
-    outfile = f'./training/{settings["name"]}.yaml'
+    if not output_name:
+        output_name = f'{settings["name"]}.yaml'
+
+    outfile = f'./training/{output_name}'
 
     with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
         yaml = f.read()
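The new output_name parameter lets callers steer where the generated YAML lands (e.g. a per-voice subfolder) while keeping the old flat ./training/{name}.yaml layout as the default. The resolution logic in isolation, as a hedged sketch (resolve_outfile is an illustrative helper, not a function in the repo):

def resolve_outfile(settings, output_name=None):
    # Default to the old flat layout when no explicit name is given.
    if not output_name:
        output_name = f'{settings["name"]}.yaml'
    return f'./training/{output_name}'

assert resolve_outfile({"name": "finetune"}) == './training/finetune.yaml'
assert resolve_outfile({"name": "finetune"}, "voice/train.yaml") == './training/voice/train.yaml'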
@@ -724,9 +728,13 @@ def get_autoregressive_models(dir="./models/finetuned/"):
     os.makedirs(dir, exist_ok=True)
     return [get_model_path('autoregressive.pth')] + sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 ])
 
+def get_dataset_list(dir="./training/"):
+    return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.txt" in os.listdir(os.path.join(dir, d)) ])
+
+def get_training_list(dir="./training/"):
+    return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and len(os.listdir(os.path.join(dir, d))) > 0 and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
+
 def update_autoregressive_model(path_name):
 
     global tts
     if not tts:
         raise Exception("TTS is uninitialized or still initializing...")
 
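Both new helpers treat ./training/ as the registry: a non-empty subdirectory counts as a dataset once it contains train.txt, and as a training configuration once it contains train.yaml. For example, given a layout like the following (voice names illustrative):

./training/
├── james/     train.txt
└── patrick/   train.txt, train.yaml

get_dataset_list() returns ['james', 'patrick'], while get_training_list() returns only ['./training/patrick/train.yaml'].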
src/webui.py | 42
@@ -131,7 +131,7 @@ def get_training_configs():
     return configs
 
 def update_training_configs():
-    return gr.update(choices=get_training_configs())
+    return gr.update(choices=get_training_list())
 
 history_headers = {
     "Name": "",
@@ -201,6 +201,24 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
 def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
     return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
 
+def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ):
+    name = f"{voice}-finetune"
+    dataset_name = f"{voice}-train"
+    dataset_path = f"./training/{voice}/train.txt"
+    validation_name = f"{voice}-val"
+    validation_path = f"./training/{voice}/train.txt"
+
+    with open(dataset_path, 'r', encoding="utf-8") as f:
+        lines = len(f.readlines())
+
+    if batch_size > lines:
+        print("Batch size is larger than your dataset, clamping...")
+        batch_size = lines
+
+    out_name = f"{voice}/train.yaml"
+
+    return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
+
 def update_voices():
     return (
         gr.Dropdown.update(choices=get_voice_list()),
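The proxy derives every name and path from the selected voice, then clamps the batch size to the number of lines in train.txt so a batch can never be larger than the dataset: with a 64-line train.txt and the default batch size of 128, the saved YAML would get batch_size 64. The clamp on its own, as a sketch (clamp_batch_size is an illustrative name):

def clamp_batch_size(batch_size, dataset_path):
    # train.txt holds one training sample per line.
    with open(dataset_path, 'r', encoding="utf-8") as f:
        lines = len(f.readlines())
    if batch_size > lines:
        print("Batch size is larger than your dataset, clamping...")
        batch_size = lines
    return batch_size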
@@ -333,6 +351,12 @@ def setup_gradio():
                         gr.Number(label="Print Frequency", value=50),
                         gr.Number(label="Save Frequency", value=50),
                     ]
+                    dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
+                    training_settings = training_settings + [
+                        dataset_list
+                    ]
+                    refresh_dataset_list = gr.Button(value="Refresh Dataset List")
+                    """
                     training_settings = training_settings + [
                         gr.Textbox(label="Training Name", placeholder="finetune"),
                         gr.Textbox(label="Dataset Name", placeholder="finetune"),
@@ -340,13 +364,14 @@ def setup_gradio():
                         gr.Textbox(label="Validation Name", placeholder="finetune"),
                         gr.Textbox(label="Validation Path", placeholder="./training/finetune/train.txt"),
                     ]
+                    """
                 with gr.Column():
                     save_yaml_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
                     save_yaml_button = gr.Button(value="Save Training Configuration")
         with gr.Tab("Run Training"):
             with gr.Row():
                 with gr.Column():
-                    training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_configs())
+                    training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
                     refresh_configs = gr.Button(value="Refresh Configurations")
                     start_training_button = gr.Button(value="Train")
                     stop_training_button = gr.Button(value="Stop")
@@ -524,7 +549,11 @@ def setup_gradio():
         outputs=input_settings
     )
 
-    refresh_configs.click(update_training_configs,inputs=None,outputs=training_configs)
+    refresh_configs.click(
+        lambda: gr.update(choices=get_training_list()),
+        inputs=None,
+        outputs=training_configs
+    )
     start_training_button.click(run_training,
         inputs=training_configs,
         outputs=training_output #console_output
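Replacing the named update_training_configs callback with an inline lambda keeps the refresh wiring next to the button it serves; gr.update(choices=...) swaps a dropdown's options in place rather than rebuilding the component. The same pattern reduced to a self-contained sketch (component names and the stub list are illustrative; assumes Gradio 3.x):

import gradio as gr

def list_configs():
    # Stand-in for get_training_list().
    return ["./training/a/train.yaml", "./training/b/train.yaml"]

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(label="Training Configuration", choices=list_configs())
    refresh = gr.Button(value="Refresh Configurations")
    # Re-query the list on click and replace the dropdown's choices in place.
    refresh.click(lambda: gr.update(choices=list_configs()), inputs=None, outputs=dropdown)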
@@ -538,7 +567,12 @@ def setup_gradio():
         inputs=dataset_settings,
         outputs=prepare_dataset_output #console_output
     )
-    save_yaml_button.click(save_training_settings,
+    refresh_dataset_list.click(
+        lambda: gr.update(choices=get_dataset_list()),
+        inputs=None,
+        outputs=dataset_list,
+    )
+    save_yaml_button.click(save_training_settings_proxy,
         inputs=training_settings,
         outputs=save_yaml_output #console_output
     )
 