From cf758f473262e4607ff18d3cb24976c86fcfb414 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 18 Feb 2023 15:50:51 +0000 Subject: [PATCH] oops --- models/.template.yaml | 8 ++++---- src/utils.py | 5 +++-- src/webui.py | 7 ++++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/models/.template.yaml b/models/.template.yaml index 64206f1..8c845ab 100755 --- a/models/.template.yaml +++ b/models/.template.yaml @@ -120,7 +120,7 @@ path: # afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit) train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH - niter: 50000 + niter: ${iterations} warmup_iter: -1 mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8]. val_freq: 500 @@ -139,8 +139,8 @@ eval: out: [gen, codebook_commitment_loss] logger: - print_freq: 100 - save_checkpoint_freq: 500 # CHANGEME: especially you should increase this it's really slow + print_freq: ${print_rate} + save_checkpoint_freq: ${save_rate} # CHANGEME: especially you should increase this it's really slow visuals: [gen, mel] - visual_debug_rate: 500 + visual_debug_rate: ${print_rate} is_mel_spectrogram: true \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index 9339a82..3ce6da9 100755 --- a/src/utils.py +++ b/src/utils.py @@ -498,9 +498,10 @@ def setup_tortoise(restart=False): print("TorToiSe initialized, ready for generation.") return tts -def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ): +def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ): settings = { - "batch_size": batch_size if batch_size else 128, + "iterations": iterations if iterations else 500, + "batch_size": batch_size if batch_size else 64, "learning_rate": learning_rate if learning_rate else 1e-5, "print_rate": print_rate if print_rate else 50, "save_rate": save_rate if save_rate else 50, diff --git a/src/webui.py b/src/webui.py index 1186b1e..2cca424 100755 --- a/src/webui.py +++ b/src/webui.py @@ -201,7 +201,7 @@ def read_generate_settings_proxy(file, saveAs='.temp'): def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ): return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress ) -def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ): +def save_training_settings_proxy( iterations, batch_size, learning_rate, print_rate, save_rate, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -217,7 +217,7 @@ def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_ra out_name = f"{voice}/train.yaml" - return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name ) + return save_training_settings(iterations, batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name ) def update_voices(): return ( @@ -346,7 +346,8 @@ def setup_gradio(): with gr.Row(): with gr.Column()