forked from mrq/ai-voice-cloning
added more safeties and parameters to training yaml generator, I think I tested it extensively enough
This commit is contained in:
parent
f4e82fcf08
commit
092dd7b2d7
|
@ -114,19 +114,19 @@ networks:
|
|||
#only_alignment_head: False # uv3/4
|
||||
|
||||
path:
|
||||
pretrain_model_gpt: './models/tortoise/autoregressive.pth' # CHANGEME: copy this from tortoise cache
|
||||
${pretrain_model_gpt}
|
||||
strict_load: true
|
||||
#resume_state: ./models/tortoise/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
|
||||
${resume_state}
|
||||
|
||||
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
|
||||
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
|
||||
niter: ${iterations}
|
||||
warmup_iter: -1
|
||||
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||
val_freq: 500
|
||||
mega_batch_factor: ${mega_batch_factor} # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||
val_freq: ${iterations}
|
||||
|
||||
default_lr_scheme: MultiStepLR
|
||||
gen_lr_steps: [500, 1000, 1400, 1800] #[50000, 100000, 140000, 180000]
|
||||
gen_lr_steps: ${gen_lr_steps} #[50000, 100000, 140000, 180000]
|
||||
lr_gamma: 0.5
|
||||
|
||||
eval:
|
||||
|
|
12
src/utils.py
12
src/utils.py
|
@ -580,11 +580,13 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
|||
|
||||
return voice
|
||||
|
||||
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
|
||||
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
|
||||
settings = {
|
||||
"iterations": iterations if iterations else 500,
|
||||
"batch_size": batch_size if batch_size else 64,
|
||||
"learning_rate": learning_rate if learning_rate else 1e-5,
|
||||
"gen_lr_steps": learning_rate_schedule if learning_rate_schedule else [ 200, 300, 400, 500 ],
|
||||
"mega_batch_factor": mega_batch_factor if mega_batch_factor else 4,
|
||||
"print_rate": print_rate if print_rate else 50,
|
||||
"save_rate": save_rate if save_rate else 50,
|
||||
"name": name if name else "finetune",
|
||||
|
@ -592,19 +594,25 @@ def save_training_settings( iterations=None, batch_size=None, learning_rate=None
|
|||
"dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
|
||||
"validation_name": validation_name if validation_name else "finetune",
|
||||
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
||||
|
||||
'resume_state': f"resume_state: '{resume_path}'" if resume_path else f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'",
|
||||
'pretrain_model_gpt': "pretrain_model_gpt: './models/tortoise/autoregressive.pth'" if not resume_path else "# pretrain_model_gpt: './models/tortoise/autoregressive.pth'"
|
||||
}
|
||||
|
||||
if not output_name:
|
||||
output_name = f'{settings["name"]}.yaml'
|
||||
|
||||
outfile = f'./training/{output_name}'
|
||||
|
||||
with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
|
||||
yaml = f.read()
|
||||
|
||||
# i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals
|
||||
for k in settings:
|
||||
if settings[k] is None:
|
||||
continue
|
||||
yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
|
||||
|
||||
outfile = f'./training/{output_name}'
|
||||
with open(outfile, 'w', encoding="utf-8") as f:
|
||||
f.write(yaml)
|
||||
|
||||
|
|
86
src/webui.py
86
src/webui.py
|
@ -47,28 +47,28 @@ def run_generation(
|
|||
):
|
||||
try:
|
||||
sample, outputs, stats = generate(
|
||||
text,
|
||||
delimiter,
|
||||
emotion,
|
||||
prompt,
|
||||
voice,
|
||||
mic_audio,
|
||||
voice_latents_chunks,
|
||||
seed,
|
||||
candidates,
|
||||
num_autoregressive_samples,
|
||||
diffusion_iterations,
|
||||
temperature,
|
||||
diffusion_sampler,
|
||||
breathing_room,
|
||||
cvvp_weight,
|
||||
top_p,
|
||||
diffusion_temperature,
|
||||
length_penalty,
|
||||
repetition_penalty,
|
||||
cond_free_k,
|
||||
experimental_checkboxes,
|
||||
progress
|
||||
text=text,
|
||||
delimiter=delimiter,
|
||||
emotion=emotion,
|
||||
prompt=prompt,
|
||||
voice=voice,
|
||||
mic_audio=mic_audio,
|
||||
voice_latents_chunks=voice_latents_chunks,
|
||||
seed=seed,
|
||||
candidates=candidates,
|
||||
num_autoregressive_samples=num_autoregressive_samples,
|
||||
diffusion_iterations=diffusion_iterations,
|
||||
temperature=temperature,
|
||||
diffusion_sampler=diffusion_sampler,
|
||||
breathing_room=breathing_room,
|
||||
cvvp_weight=cvvp_weight,
|
||||
top_p=top_p,
|
||||
diffusion_temperature=diffusion_temperature,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
cond_free_k=cond_free_k,
|
||||
experimental_checkboxes=experimental_checkboxes,
|
||||
progress=progress
|
||||
)
|
||||
except Exception as e:
|
||||
message = str(e)
|
||||
|
@ -180,7 +180,7 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
|||
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
||||
|
||||
def save_training_settings_proxy( iterations, batch_size, learning_rate, print_rate, save_rate, voice ):
|
||||
def save_training_settings_proxy( iterations, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
|
@ -190,13 +190,44 @@ def save_training_settings_proxy( iterations, batch_size, learning_rate, print_r
|
|||
with open(dataset_path, 'r', encoding="utf-8") as f:
|
||||
lines = len(f.readlines())
|
||||
|
||||
messages = []
|
||||
|
||||
if batch_size > lines:
|
||||
print("Batch size is larger than your dataset, clamping...")
|
||||
batch_size = lines
|
||||
messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}")
|
||||
|
||||
if batch_size / mega_batch_factor < 2:
|
||||
mega_batch_factor = int(batch_size / 2)
|
||||
messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}")
|
||||
|
||||
out_name = f"{voice}/train.yaml"
|
||||
if iterations < print_rate:
|
||||
print_rate = iterations
|
||||
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}")
|
||||
|
||||
if iterations < save_rate:
|
||||
save_rate = iterations
|
||||
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
||||
|
||||
return save_training_settings(iterations, batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
|
||||
if resume_path and not os.path.exists(resume_path):
|
||||
messages.append("Resume path specified, but does not exist. Disabling...")
|
||||
resume_path = None
|
||||
|
||||
messages.append(save_training_settings(iterations,
|
||||
batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
learning_rate_schedule=learning_rate_schedule,
|
||||
mega_batch_factor=mega_batch_factor,
|
||||
print_rate=print_rate,
|
||||
save_rate=save_rate,
|
||||
name=name,
|
||||
dataset_name=dataset_name,
|
||||
dataset_path=dataset_path,
|
||||
validation_name=validation_name,
|
||||
validation_path=validation_path,
|
||||
output_name=f"{voice}/train.yaml",
|
||||
resume_path=resume_path,
|
||||
))
|
||||
return "\n".join(messages)
|
||||
|
||||
def update_voices():
|
||||
return (
|
||||
|
@ -326,8 +357,11 @@ def setup_gradio():
|
|||
gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
|
||||
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||
gr.Textbox(label="Learning Rate Schedule", placeholder="[ 200, 300, 400, 500 ]"),
|
||||
gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4),
|
||||
gr.Number(label="Print Frequency", value=50),
|
||||
gr.Number(label="Save Frequency", value=50),
|
||||
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
||||
]
|
||||
dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
|
||||
training_settings = training_settings + [
|
||||
|
|
Loading…
Reference in New Issue
Block a user