From 092dd7b2d78b89abc0f1855aeb1f3bee83d3eb7f Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 19 Feb 2023 16:16:44 +0000
Subject: [PATCH] added more safeties and parameters to training yaml generator, I think I tested it extensively enough

---
 models/.template.yaml | 10 ++---
 src/utils.py          | 12 +++++-
 src/webui.py          | 86 ++++++++++++++++++++++++++++++-------------
 3 files changed, 75 insertions(+), 33 deletions(-)

diff --git a/models/.template.yaml b/models/.template.yaml
index 8c845ab..dc4b1ff 100755
--- a/models/.template.yaml
+++ b/models/.template.yaml
@@ -114,19 +114,19 @@ networks:
   #only_alignment_head: False # uv3/4
 
 path:
-  pretrain_model_gpt: './models/tortoise/autoregressive.pth' # CHANGEME: copy this from tortoise cache
+  ${pretrain_model_gpt}
   strict_load: true
-  #resume_state: ./models/tortoise/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
+  ${resume_state}
 
 # afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
 train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
   niter: ${iterations}
   warmup_iter: -1
-  mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
-  val_freq: 500
+  mega_batch_factor: ${mega_batch_factor} # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
+  val_freq: ${iterations}
 
   default_lr_scheme: MultiStepLR
-  gen_lr_steps: [500, 1000, 1400, 1800] #[50000, 100000, 140000, 180000]
+  gen_lr_steps: ${gen_lr_steps} #[50000, 100000, 140000, 180000]
   lr_gamma: 0.5
 
 eval:
diff --git a/src/utils.py b/src/utils.py
index be0e304..48ea2b8 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -580,11 +580,13 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 
     return voice
 
-def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
+def save_training_settings( iterations=None, batch_size=None, learning_rate=None, learning_rate_schedule=None, mega_batch_factor=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None ):
     settings = {
         "iterations": iterations if iterations else 500,
         "batch_size": batch_size if batch_size else 64,
         "learning_rate": learning_rate if learning_rate else 1e-5,
+        "gen_lr_steps": learning_rate_schedule if learning_rate_schedule else [ 200, 300, 400, 500 ],
+        "mega_batch_factor": mega_batch_factor if mega_batch_factor else 4,
         "print_rate": print_rate if print_rate else 50,
         "save_rate": save_rate if save_rate else 50,
         "name": name if name else "finetune",
@@ -592,19 +594,25 @@ def save_training_settings( iterations=None, batch_size=None, learning_rate=None
         "dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
         "validation_name": validation_name if validation_name else "finetune",
         "validation_path": validation_path if validation_path else "./training/finetune/train.txt",
+
+        'resume_state': f"resume_state: '{resume_path}'" if resume_path else f"# resume_state: './training/{name if name else 'finetune'}/training_state/#.state'",
+        'pretrain_model_gpt': "pretrain_model_gpt: './models/tortoise/autoregressive.pth'" if not resume_path else "# pretrain_model_gpt: './models/tortoise/autoregressive.pth'"
     }
 
     if not output_name:
         output_name = f'{settings["name"]}.yaml'
 
-    outfile = f'./training/{output_name}'
     with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f:
         yaml = f.read()
 
+    # i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals
     for k in settings:
+        if settings[k] is None:
+            continue
         yaml = yaml.replace(f"${{{k}}}", str(settings[k]))
 
+    outfile = f'./training/{output_name}'
     with open(outfile, 'w', encoding="utf-8") as f:
         f.write(yaml)
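Note (not part of the patch): once the placeholders above are substituted, the path: block of the generated YAML comes out one of two ways. A sketch, assuming the default name "finetune" and a hypothetical saved state at step 500:

    # fresh finetune: load the base autoregressive model, keep resume_state commented out
    path:
      pretrain_model_gpt: './models/tortoise/autoregressive.pth'
      strict_load: true
      # resume_state: './training/finetune/training_state/#.state'

    # resuming: the roles flip, and the base model line is the one commented out
    path:
      # pretrain_model_gpt: './models/tortoise/autoregressive.pth'
      strict_load: true
      resume_state: './training/finetune/training_state/500.state'

Emitting the whole "key: value" line (or its commented-out form) from Python, rather than only a value, is what lets a single template slot toggle a YAML line on or off.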
f'{settings["name"]}.yaml' - outfile = f'./training/{output_name}' with open(f'./models/.template.yaml', 'r', encoding="utf-8") as f: yaml = f.read() + # i could just load and edit the YAML directly, but this is easier, as I don't need to bother with path traversals for k in settings: + if settings[k] is None: + continue yaml = yaml.replace(f"${{{k}}}", str(settings[k])) + outfile = f'./training/{output_name}' with open(outfile, 'w', encoding="utf-8") as f: f.write(yaml) diff --git a/src/webui.py b/src/webui.py index ba9c7a3..6cb48cc 100755 --- a/src/webui.py +++ b/src/webui.py @@ -47,28 +47,28 @@ def run_generation( ): try: sample, outputs, stats = generate( - text, - delimiter, - emotion, - prompt, - voice, - mic_audio, - voice_latents_chunks, - seed, - candidates, - num_autoregressive_samples, - diffusion_iterations, - temperature, - diffusion_sampler, - breathing_room, - cvvp_weight, - top_p, - diffusion_temperature, - length_penalty, - repetition_penalty, - cond_free_k, - experimental_checkboxes, - progress + text=text, + delimiter=delimiter, + emotion=emotion, + prompt=prompt, + voice=voice, + mic_audio=mic_audio, + voice_latents_chunks=voice_latents_chunks, + seed=seed, + candidates=candidates, + num_autoregressive_samples=num_autoregressive_samples, + diffusion_iterations=diffusion_iterations, + temperature=temperature, + diffusion_sampler=diffusion_sampler, + breathing_room=breathing_room, + cvvp_weight=cvvp_weight, + top_p=top_p, + diffusion_temperature=diffusion_temperature, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + cond_free_k=cond_free_k, + experimental_checkboxes=experimental_checkboxes, + progress=progress ) except Exception as e: message = str(e) @@ -180,7 +180,7 @@ def read_generate_settings_proxy(file, saveAs='.temp'): def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ): return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress ) -def save_training_settings_proxy( iterations, batch_size, learning_rate, print_rate, save_rate, voice ): +def save_training_settings_proxy( iterations, batch_size, learning_rate, learning_rate_schedule, mega_batch_factor, print_rate, save_rate, resume_path, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -190,13 +190,44 @@ def save_training_settings_proxy( iterations, batch_size, learning_rate, print_r with open(dataset_path, 'r', encoding="utf-8") as f: lines = len(f.readlines()) + messages = [] + if batch_size > lines: - print("Batch size is larger than your dataset, clamping...") batch_size = lines + messages.append(f"Batch size is larger than your dataset, clamping batch size to: {batch_size}") + + if batch_size / mega_batch_factor < 2: + mega_batch_factor = int(batch_size / 2) + messages.append(f"Mega batch factor is too large for the given batch size, clamping mega batch factor to: {mega_batch_factor}") - out_name = f"{voice}/train.yaml" + if iterations < print_rate: + print_rate = iterations + messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {print_rate}") + + if iterations < save_rate: + save_rate = iterations + messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}") - return save_training_settings(iterations, batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name 
@@ -326,8 +357,11 @@ def setup_gradio():
                     gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
                     gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
                     gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
+                    gr.Textbox(label="Learning Rate Schedule", placeholder="[ 200, 300, 400, 500 ]"),
+                    gr.Slider(label="Mega Batch Factor", minimum=1, maximum=16, value=4),
                     gr.Number(label="Print Frequency", value=50),
                     gr.Number(label="Save Frequency", value=50),
+                    gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
                 ]
                 dataset_list = gr.Dropdown( get_dataset_list(), label="Dataset", type="value" )
                 training_settings = training_settings + [
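Note (not part of the patch): two remarks on the new inputs. The Learning Rate Schedule textbox is handed to save_training_settings() as-is and dropped into the template via str(), so a bracketed list like the placeholder [ 200, 300, 400, 500 ] lands verbatim in the YAML as:

    gen_lr_steps: [ 200, 300, 400, 500 ]

which YAML parses as a flow sequence; an empty textbox is falsy and falls back to that same default. The ${voice} and ${last_state} in the Resume State Path placeholder are display hints for the user, not substituted template slots; whatever path is entered is checked with os.path.exists() and disabled with a warning when it does not exist.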