forked from mrq/ai-voice-cloning
oops
This commit is contained in:
parent
843bfbfb96
commit
cf758f4732
|
@ -120,7 +120,7 @@ path:
|
||||||
|
|
||||||
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
|
# afaik all units here are measured in **steps** (i.e. one batch of batch_size is 1 unit)
|
||||||
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
|
train: # CHANGEME: ALL OF THESE PARAMETERS SHOULD BE EXPERIMENTED WITH
|
||||||
niter: 50000
|
niter: ${iterations}
|
||||||
warmup_iter: -1
|
warmup_iter: -1
|
||||||
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||||
val_freq: 500
|
val_freq: 500
|
||||||
|
@ -139,8 +139,8 @@ eval:
|
||||||
out: [gen, codebook_commitment_loss]
|
out: [gen, codebook_commitment_loss]
|
||||||
|
|
||||||
logger:
|
logger:
|
||||||
print_freq: 100
|
print_freq: ${print_rate}
|
||||||
save_checkpoint_freq: 500 # CHANGEME: especially you should increase this it's really slow
|
save_checkpoint_freq: ${save_rate} # CHANGEME: especially you should increase this it's really slow
|
||||||
visuals: [gen, mel]
|
visuals: [gen, mel]
|
||||||
visual_debug_rate: 500
|
visual_debug_rate: ${print_rate}
|
||||||
is_mel_spectrogram: true
|
is_mel_spectrogram: true
|
|
@ -498,9 +498,10 @@ def setup_tortoise(restart=False):
|
||||||
print("TorToiSe initialized, ready for generation.")
|
print("TorToiSe initialized, ready for generation.")
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
def save_training_settings( batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
|
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
|
||||||
settings = {
|
settings = {
|
||||||
"batch_size": batch_size if batch_size else 128,
|
"iterations": iterations if iterations else 500,
|
||||||
|
"batch_size": batch_size if batch_size else 64,
|
||||||
"learning_rate": learning_rate if learning_rate else 1e-5,
|
"learning_rate": learning_rate if learning_rate else 1e-5,
|
||||||
"print_rate": print_rate if print_rate else 50,
|
"print_rate": print_rate if print_rate else 50,
|
||||||
"save_rate": save_rate if save_rate else 50,
|
"save_rate": save_rate if save_rate else 50,
|
||||||
|
|
|
@ -201,7 +201,7 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
|
||||||
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
return prepare_dataset( get_voices(load_latents=False)[voice], outdir=f"./training/{voice}/", language=language, progress=progress )
|
||||||
|
|
||||||
def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_rate, voice ):
|
def save_training_settings_proxy( iterations, batch_size, learning_rate, print_rate, save_rate, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -217,7 +217,7 @@ def save_training_settings_proxy( batch_size, learning_rate, print_rate, save_ra
|
||||||
|
|
||||||
out_name = f"{voice}/train.yaml"
|
out_name = f"{voice}/train.yaml"
|
||||||
|
|
||||||
return save_training_settings(batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
|
return save_training_settings(iterations, batch_size, learning_rate, print_rate, save_rate, name, dataset_name, dataset_path, validation_name, validation_path, out_name )
|
||||||
|
|
||||||
def update_voices():
|
def update_voices():
|
||||||
return (
|
return (
|
||||||
|
@ -346,7 +346,8 @@ def setup_gradio():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
training_settings = [
|
training_settings = [
|
||||||
gr.Slider(label="Batch Size", value=128),
|
gr.Slider(label="Iterations", minimum=0, maximum=5000, value=500),
|
||||||
|
gr.Slider(label="Batch Size", minimum=2, maximum=128, value=64),
|
||||||
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
gr.Slider(label="Learning Rate", value=1e-5, minimum=0, maximum=1e-4, step=1e-6),
|
||||||
gr.Number(label="Print Frequency", value=50),
|
gr.Number(label="Print Frequency", value=50),
|
||||||
gr.Number(label="Save Frequency", value=50),
|
gr.Number(label="Save Frequency", value=50),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user