|
|
@ -47,28 +47,28 @@ def run_generation(
|
|
|
|
):
|
|
|
|
):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
sample, outputs, stats = generate(
|
|
|
|
sample, outputs, stats = generate(
|
|
|
|
text,
|
|
|
|
text=text,
|
|
|
|
delimiter,
|
|
|
|
delimiter=delimiter,
|
|
|
|
emotion,
|
|
|
|
emotion=emotion,
|
|
|
|
prompt,
|
|
|
|
prompt=prompt,
|
|
|
|
voice,
|
|
|
|
voice=voice,
|
|
|
|
mic_audio,
|
|
|
|
mic_audio=mic_audio,
|
|
|
|
voice_latents_chunks,
|
|
|
|
voice_latents_chunks=voice_latents_chunks,
|
|
|
|
seed,
|
|
|
|
seed=seed,
|
|
|
|
candidates,
|
|
|
|
candidates=candidates,
|
|
|
|
num_autoregressive_samples,
|
|
|
|
num_autoregressive_samples=num_autoregressive_samples,
|
|
|
|
diffusion_iterations,
|
|
|
|
diffusion_iterations=diffusion_iterations,
|
|
|
|
temperature,
|
|
|
|
temperature=temperature,
|
|
|
|
diffusion_sampler,
|
|
|
|
diffusion_sampler=diffusion_sampler,
|
|
|
|
breathing_room,
|
|
|
|
breathing_room=breathing_room,
|
|
|
|
cvvp_weight,
|
|
|
|
cvvp_weight=cvvp_weight,
|
|
|
|
top_p,
|
|
|
|
top_p=top_p,
|
|
|
|
diffusion_temperature,
|
|
|
|
diffusion_temperature=diffusion_temperature,
|
|
|
|
length_penalty,
|
|
|
|
length_penalty=length_penalty,
|
|
|
|
repetition_penalty,
|
|
|
|
repetition_penalty=repetition_penalty,
|
|
|
|
cond_free_k,
|
|
|
|
cond_free_k=cond_free_k,
|
|
|
|
experimental_checkboxes,
|
|
|
|
experimental_checkboxes=experimental_checkboxes,
|
|
|
|
progress
|
|
|
|
progress=progress
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
message = str(e)
|
|
|
|
message = str(e)
|
|
|
@ -180,7 +180,7 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
|
|
|
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
|
|
|
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
|
|
|
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
|
|
|
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
|
|
|
|
|
|
|
|
|
|
|
def save_training_settings_proxy( iterations, batch_size, learning_rate, print_rate, save_rate, voice ):
|
|
|
|
def save_training_settings_proxy( iterations, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
|
|
|
|
name = f"{voice}-finetune"
|
|
|
|
name = f"{voice}-finetune"
|
|
|
|
dataset_name = f"{voice}-train"
|
|
|
|
dataset_name = f"{voice}-train"
|
|
|
|
dataset_path = f"./training/{voice}/train.txt"
|
|
|
|
dataset_path = f"./training/{voice}/train.txt"
|
|
|
@ -190,13 +190,44 @@ def save_training_settings_proxy( iterations, batch_size, learning_rate, print_r
|
|
|
|
with open(dataset_path, 'r', encoding="utf-8") as f:
|
|
|
|
with open(dataset_path, 'r', encoding="utf-8") as f:
|
|
|
|
lines = len(f.readlines())
|
|
|
|
lines = len(f.readlines())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
messages = []
|
|
|
|
|
|
|
|
|
|
|
|
if batch_size > lines:
|
|
|
|
if batch_size > lines:
|
|
|
|
print("Batch size is larger than your dataset, clamping...")
|
|
|
|
|
|
|
|
batch_size = lines
|
|
|
|
batch_size = lines
|
|
|
|
|
|
|
|
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if batch_size / mega_batch_factor < 2:
|
|
|
|
|
|
|
|
mega_batch_factor = int(batch_size / 2)
|
|
|
|
|
|
|
|
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
|
|
|
|
|
|
|
|
|
|
|
|
out_name = f"{voice}/train.yaml"
|
|
|
|
if iterations < print_rate:
|
|
|
|
|
|
|
|
print_rate = iterations
|
|
|
|
return save_training_settings(iterations, batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
|
|
|
|
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if iterations < save_rate:
|
|
|
|
|
|
|
|
save_rate = iterations
|
|
|
|
|
|
|
|
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if resume_path and not os.path.exists(resume_path):
|
|
|
|
|
|
|
|
messages.append("Resume path specified, but does not exist. Disabling...")
|
|
|
|
|
|
|
|
resume_path = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
messages.append(save_training_settings(iterations,
|
|
|
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
|
|
|
learning_rate=learning_rate,
|
|
|
|
|
|
|
|
learning_rate_schedule=learning_rate_schedule,
|
|
|
|
|
|
|
|
mega_batch_factor=mega_batch_factor,
|
|
|
|
|
|
|
|
print_rate=print_rate,
|
|
|
|
|
|
|
|
save_rate=save_rate,
|
|
|
|
|
|
|
|
name=name,
|
|
|
|
|
|
|
|
dataset_name=dataset_name,
|
|
|
|
|
|
|
|
dataset_path=dataset_path,
|
|
|
|
|
|
|
|
validation_name=validation_name,
|
|
|
|
|
|
|
|
validation_path=validation_path,
|
|
|
|
|
|
|
|
output_name=f"{voice}/train.yaml",
|
|
|
|
|
|
|
|
resume_path=resume_path,
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
return "\n".join(messages)
|
|
|
|
|
|
|
|
|
|
|
|
def update_voices():
|
|
|
|
def update_voices():
|
|
|
|
return (
|
|
|
|
return (
|
|
|
@ -326,8 +357,11 @@ def setup_gradio():
|
|
|
|
gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
|
|
|
|
gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
|
|
|
|
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
|
|
|
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
|
|
|
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
|
|
|
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
|
|
|
|
|
|
|
gr.Textbox(label="Learning Rate Schedule", placeholder="[ 200, 300, 400, 500 ]"),
|
|
|
|
|
|
|
|
gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4),
|
|
|
|
gr.Number(label="Print Frequency", value=50),
|
|
|
|
gr.Number(label="Print Frequency", value=50),
|
|
|
|
gr.Number(label="Save Frequency", value=50),
|
|
|
|
gr.Number(label="Save Frequency", value=50),
|
|
|
|
|
|
|
|
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
|
|
|
]
|
|
|
|
]
|
|
|
|
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
|
|
|
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
|
|
|
training_settings = training_settings + [
|
|
|
|
training_settings = training_settings + [
|
|
|
|