VALL-E config edits

This commit is contained in:
mrq 2023-03-20 01:22:53 +00:00
parent 2e33bf071a
commit 34ef0467b9
3 changed files with 62 additions and 39 deletions

View File

@ -1,41 +1,61 @@
{ {
"autotuning": { "optimizer": {
"enabled": false, "type": "AdamW",
"results_dir": "./config/autotune/results", "params": {
"exps_dir": "./config/autotune/exps", "lr": 2e-05,
"overwrite": false, "betas": [
"metric": "throughput", 0.9,
"start_profile_step": 10, 0.96
"end_profile_step": 20, ],
"fast": false, "eps": 1e-07,
"max_train_batch_size": 32, "weight_decay": 0.01
"mp_size": 1, }
"num_tuning_micro_batch_sizes": 3, },
"tuner_type": "model_based", "scheduler":{
"tuner_early_stopping": 5, "type":"WarmupLR",
"tuner_num_trials": 50, "params":{
"arg_mappings": { "warmup_min_lr":0,
"train_micro_batch_size_per_gpu": "--per_device_train_batch_size", "warmup_max_lr":2e-5,
"gradient_accumulation_steps ": "--gradient_accumulation_steps" "warmup_num_steps":100,
"warmup_type":"linear"
} }
}, },
"zero_optimization": { "fp16":{
"stage": 0, "enabled":true,
"offload_param": { "loss_scale":0,
"device": "nvme", "loss_scale_window":1000,
"nvme_path": "/tmp/zero/", "initial_scale_power":16,
"pin_memory": false, "hysteresis":2,
"buffer_count": 5, "min_loss_scale":1
"buffer_size": 1e9, },
"max_in_cpu": 1e9 "autotuning":{
}, "enabled":false,
"overlap_comm": true, "results_dir":"./config/autotune/results",
"reduce_bucket_size": "auto", "exps_dir":"./config/autotune/exps",
"contiguous_gradients": true, "overwrite":false,
"sub_group_size": 1e8, "metric":"throughput",
"stage3_prefetch_bucket_size": "auto", "start_profile_step":10,
"stage3_param_persistence_threshold": "auto", "end_profile_step":20,
"stage3_max_live_parameters": "auto", "fast":false,
"stage3_max_reuse_distance": "auto" "max_train_batch_size":32,
"mp_size":1,
"num_tuning_micro_batch_sizes":3,
"tuner_type":"model_based",
"tuner_early_stopping":5,
"tuner_num_trials":50,
"arg_mappings":{
"train_micro_batch_size_per_gpu":"--per_device_train_batch_size",
"gradient_accumulation_steps ":"--gradient_accumulation_steps"
}
},
"zero_optimization":{
"stage":0,
"reduce_bucket_size":"auto",
"contiguous_gradients":true,
"sub_group_size":1e8,
"stage3_prefetch_bucket_size":"auto",
"stage3_param_persistence_threshold":"auto",
"stage3_max_live_parameters":"auto",
"stage3_max_reuse_distance":"auto"
} }
} }

View File

@ -3,14 +3,17 @@ ckpt_root: ./training/${voice}/finetune/ckpt/
log_root: ./training/${voice}/finetune/logs/ log_root: ./training/${voice}/finetune/logs/
data_dirs: [./training/${voice}/valle/] data_dirs: [./training/${voice}/valle/]
spkr_name_getter: "lambda p: p.parts[-3]" spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
model: ${model_name} model: ${model_name}
batch_size: ${batch_size} batch_size: ${batch_size}
eval_batch_size: ${validation_batch_size} gradient_accumulation_steps: ${gradient_accumulation_size}
eval_batch_size: ${batch_size}
max_iter: ${iterations} max_iter: ${iterations}
save_ckpt_every: ${save_rate} save_ckpt_every: ${save_rate}
eval_every: ${validation_rate} eval_every: ${validation_rate}
max_phones: 256
sampling_temperature: 1.0 sampling_temperature: 1.0

View File

@ -488,7 +488,7 @@ def setup_gradio():
) )
with gr.Row(): with gr.Row():
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0) TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0, visible=args.tts_backend=="tortoise") TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
with gr.Row(): with gr.Row():
TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0) TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0) TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)