This commit is contained in:
mrq 2023-02-18 15:50:51 +00:00
parent 843bfbfb96
commit cf758f4732
3 changed files with 11 additions and 9 deletions

View File

@ -120,7 +120,7 @@ path:
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
niter: 50000
niter: ${iterations}
warmup_iter: -1
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
val_freq: 500
@ -139,8 +139,8 @@ eval:
out: [gen, codebook_commitment_loss]
logger:
print_freq: 100
save_checkpoint_freq: 500 # CHANGEME: especially you should increase this it's really slow
print_freq: ${print_rate}
save_checkpoint_freq: ${save_rate} # CHANGEME: especially you should increase this it's really slow
visuals: [gen, mel]
visual_debug_rate: 500
visual_debug_rate: ${print_rate}
is_mel_spectrogram: true

View File

@ -498,9 +498,10 @@ def setup_tortoise(restart=False):
print("TorToiSe initialized, ready for generation.")
return tts
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
settings = {
"batch_size": batch_size if batch_size else 128,
"iterations": iterations if iterations else 500,
"batch_size": batch_size if batch_size else 64,
"learning_rate": learning_rate if learning_rate else 1e-5,
"print_rate": print_rate if print_rate else 50,
"save_rate": save_rate if save_rate else 50,

View File

@ -201,7 +201,7 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ):
def save_training_settings_proxy( iterations, batch_size, learning_rate, print_rate, save_rate, voice ):
name = f"{voice}-finetune"
dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt"
@ -217,7 +217,7 @@ def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_ra
out_name = f"{voice}/train.yaml"
return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
return save_training_settings(iterations, batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
def update_voices():
return (
@ -346,7 +346,8 @@ def setup_gradio():
with gr.Row():
with gr.Column():
training_settings = [
gr.Slider(label="Batch Size", value=128),
gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
gr.Number(label="Print Frequency", value=50),
gr.Number(label="Save Frequency", value=50),