This commit is contained in:
mrq 2023-02-18 15:50:51 +00:00
parent 843bfbfb96
commit cf758f4732
3 changed files with 11 additions and 9 deletions

View File

@ -120,7 +120,7 @@ path:
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit) # afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
niter: 50000 niter: ${iterations}
warmup_iter: -1 warmup_iter: -1
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8]. mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
val_freq: 500 val_freq: 500
@ -139,8 +139,8 @@ eval:
out: [gen, codebook_commitment_loss] out: [gen, codebook_commitment_loss]
logger: logger:
print_freq: 100 print_freq: ${print_rate}
save_checkpoint_freq: 500 # CHANGEME: especially you should increase this it's really slow save_checkpoint_freq: ${save_rate} # CHANGEME: especially you should increase this it's really slow
visuals: [gen, mel] visuals: [gen, mel]
visual_debug_rate: 500 visual_debug_rate: ${print_rate}
is_mel_spectrogram: true is_mel_spectrogram: true

View File

@ -498,9 +498,10 @@ def setup_tortoise(restart=False):
print("TorToiSe initialized, ready for generation.") print("TorToiSe initialized, ready for generation.")
return tts return tts
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ): def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
settings = { settings = {
"batch_size": batch_size if batch_size else 128, "iterations": iterations if iterations else 500,
"batch_size": batch_size if batch_size else 64,
"learning_rate": learning_rate if learning_rate else 1e-5, "learning_rate": learning_rate if learning_rate else 1e-5,
"print_rate": print_rate if print_rate else 50, "print_rate": print_rate if print_rate else 50,
"save_rate": save_rate if save_rate else 50, "save_rate": save_rate if save_rate else 50,

View File

@ -201,7 +201,7 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ): def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress ) return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ): def save_training_settings_proxy( iterations, batch_size, learning_rate, print_rate, save_rate, voice ):
name = f"{voice}-finetune" name = f"{voice}-finetune"
dataset_name = f"{voice}-train" dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt" dataset_path = f"./training/{voice}/train.txt"
@ -217,7 +217,7 @@ def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_ra
out_name = f"{voice}/train.yaml" out_name = f"{voice}/train.yaml"
return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name ) return save_training_settings(iterations, batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
def update_voices(): def update_voices():
return ( return (
@ -346,7 +346,8 @@ def setup_gradio():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
training_settings = [ training_settings = [
gr.Slider(label="Batch Size", value=128), gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6), gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
gr.Number(label="Print Frequency", value=50), gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50), gr.Number(label="Save Frequency", value=50),