oops
This commit is contained in:
parent
843bfbfb96
commit
cf758f4732
|
@ -120,7 +120,7 @@ path:
|
|||
|
||||
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
|
||||
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
|
||||
niter: 50000
|
||||
niter: ${iterations}
|
||||
warmup_iter: -1
|
||||
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||
val_freq: 500
|
||||
|
@ -139,8 +139,8 @@ eval:
|
|||
out: [gen, codebook_commitment_loss]
|
||||
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: 500 # CHANGEME: especially you should increase this it's really slow
|
||||
print_freq: ${print_rate}
|
||||
save_checkpoint_freq: ${save_rate} # CHANGEME: especially you should increase this it's really slow
|
||||
visuals: [gen, mel]
|
||||
visual_debug_rate: 500
|
||||
visual_debug_rate: ${print_rate}
|
||||
is_mel_spectrogram: true
|
|
@ -498,9 +498,10 @@ def setup_tortoise(restart=False):
|
|||
print("TorToiSe initialized, ready for generation.")
|
||||
return tts
|
||||
|
||||
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
|
||||
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
|
||||
settings = {
|
||||
"batch_size": batch_size if batch_size else 128,
|
||||
"iterations": iterations if iterations else 500,
|
||||
"batch_size": batch_size if batch_size else 64,
|
||||
"learning_rate": learning_rate if learning_rate else 1e-5,
|
||||
"print_rate": print_rate if print_rate else 50,
|
||||
"save_rate": save_rate if save_rate else 50,
|
||||
|
|
|
@ -201,7 +201,7 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
|||
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
||||
|
||||
def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ):
|
||||
def save_training_settings_proxy( iterations, batch_size, learning_rate, print_rate, save_rate, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
|
@ -217,7 +217,7 @@ def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_ra
|
|||
|
||||
out_name = f"{voice}/train.yaml"
|
||||
|
||||
return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
|
||||
return save_training_settings(iterations, batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
|
||||
|
||||
def update_voices():
|
||||
return (
|
||||
|
@ -346,7 +346,8 @@ def setup_gradio():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
training_settings = [
|
||||
gr.Slider(label="Batch Size", value=128),
|
||||
gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
|
||||
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||
gr.Number(label="Print Frequency", value=50),
|
||||
gr.Number(label="Save Frequency", value=50),
|
||||
|
|
Loading…
Reference in New Issue
Block a user