forked from mrq/ai-voice-cloning
set validation to save rate and validation file if exists (need to test later)
This commit is contained in:
parent fe8bf7a9d1
commit e862169e7f
@@ -122,7 +122,7 @@ train:
   niter: ${iterations}
   warmup_iter: -1
   mega_batch_factor: ${gradient_accumulation_size}
-  val_freq: ${iterations}
+  val_freq: ${validation_rate}

   ema_enabled: false # I really don't think EMA matters

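The ${...} tokens in this template are placeholders that get filled in when the training YAML is generated. Below is a minimal sketch of that substitution, assuming plain string replacement; fill_template and TEMPLATE are hypothetical names, not this repo's API. It shows the practical effect of the hunk: val_freq now tracks a separate validation_rate value instead of always equaling the total iteration count.

# Minimal sketch, assuming ${key} placeholders are filled by plain string
# replacement; fill_template is a hypothetical helper, not this repo's API.
TEMPLATE = """train:
  niter: ${iterations}
  val_freq: ${validation_rate}
"""

def fill_template(template: str, settings: dict) -> str:
    # Swap every ${key} token for its value from the settings dict.
    for key, value in settings.items():
        template = template.replace(f"${{{key}}}", str(value))
    return template

print(fill_template(TEMPLATE, {"iterations": 500, "validation_rate": 50}))
# Before this commit val_freq was ${iterations}, so validation ran only once,
# at the very end; now it can run every validation_rate iterations (here, 50).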
@@ -1228,10 +1228,7 @@ def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):

 def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
 	name = f"{voice}-finetune"
-	dataset_name = f"{voice}-train"
 	dataset_path = f"./training/{voice}/train.txt"
-	validation_name = f"{voice}-val"
-	validation_path = f"./training/{voice}/train.txt"

 	with open(dataset_path, 'r', encoding="utf-8") as f:
 		lines = len(f.readlines())
@@ -1303,7 +1300,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
 		messages
 	)

-def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
+def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
 	if not source_model:
 		source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"

@@ -1320,6 +1317,7 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
 		"dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
 		"validation_name": validation_name if validation_name else "finetune",
 		"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
+		'validation_rate': validation_rate if validation_rate else iterations,

 		"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
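The new 'validation_rate' entry falls back to iterations when no rate is supplied, which reproduces the old behaviour: validate once, at the very end of training. A small standalone illustration of that defaulting rule (resolve_val_freq is a hypothetical name, not a function in this repo):

# Standalone illustration of the fallback above; resolve_val_freq is
# hypothetical, but the expression mirrors the added dict entry.
def resolve_val_freq(validation_rate, iterations):
    return validation_rate if validation_rate else iterations

assert resolve_val_freq(None, 500) == 500  # old behaviour: validate once, at the end
assert resolve_val_freq(50, 500) == 50     # new behaviour: validate every 50 iterations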
17 src/webui.py
@@ -300,7 +300,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
 	dataset_name = f"{voice}-train"
 	dataset_path = f"./training/{voice}/train.txt"
 	validation_name = f"{voice}-val"
-	validation_path = f"./training/{voice}/train.txt"
+	validation_path = f"./training/{voice}/validation.txt"

 	with open(dataset_path, 'r', encoding="utf-8") as f:
 		lines = len(f.readlines())
@@ -312,6 +312,16 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear

 	print_rate = int(print_rate * iterations / epochs)
 	save_rate = int(save_rate * iterations / epochs)
+	validation_rate = save_rate
+
+	if iterations % save_rate != 0:
+		adjustment = int(iterations / save_rate) * save_rate
+		messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}")
+		iterations = adjustment
+
+	if not os.path.exists(validation_path):
+		validation_rate = iterations
+		validation_path = dataset_path

 	if not learning_rate_schedule:
 		learning_rate_schedule = EPOCH_SCHEDULE
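Taken together, the added block does three things: it defaults the validation rate to the epoch-scaled save rate, rounds the iteration count down to a multiple of the save rate so the last checkpoint and validation pass land exactly on the final iteration, and falls back to the training set with a single end-of-run validation when no validation.txt exists for the voice. A self-contained sketch of that logic; the wrapper function and its return tuple are hypothetical (in the commit this runs inline inside save_training_settings_proxy), but the body mirrors the added lines:

import os

# Illustrative standalone version of the inline logic above.
def resolve_validation(iterations, save_rate, validation_path, dataset_path, messages):
    validation_rate = save_rate

    # Round iterations down to a multiple of save_rate so the final
    # checkpoint (and validation pass) lands on the last iteration.
    if iterations % save_rate != 0:
        adjustment = int(iterations / save_rate) * save_rate
        messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}")
        iterations = adjustment

    # No validation set prepared for this voice: validate against the
    # training set instead, and only once, at the end of the run.
    if not os.path.exists(validation_path):
        validation_rate = iterations
        validation_path = dataset_path

    return iterations, validation_rate, validation_path

messages = []
print(resolve_validation(505, 50, "./training/demo/validation.txt", "./training/demo/train.txt", messages))
# With no validation.txt on disk this prints (500, 500, './training/demo/train.txt');
# with one present it would print (500, 50, './training/demo/validation.txt').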
@@ -329,6 +339,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
 		gradient_accumulation_size=gradient_accumulation_size,
 		print_rate=print_rate,
 		save_rate=save_rate,
+		validation_rate=validation_rate,
 		name=name,
 		dataset_name=dataset_name,
 		dataset_path=dataset_path,
@@ -559,7 +570,7 @@ def setup_gradio():
 					verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)

 					with gr.Row():
-						training_keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
+						training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
 						training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
 					with gr.Row():
 						start_training_button = gr.Button(value="Train")
@@ -777,7 +788,7 @@ def setup_gradio():
 					training_configs,
 					verbose_training,
 					training_gpu_count,
-					training_keep_x_past_checkpoints,
+					training_keep_x_past_datasets,
 				],
 				outputs=[
 					training_output,
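The last two hunks are a coordinated rename (training_keep_x_past_checkpoints to training_keep_x_past_datasets): the slider definition and the entry in the Train button's inputs list must reference the same Python variable, or the click handler would be wired to a stale component. A minimal sketch of that Gradio wiring pattern; the handler body and layout are illustrative, not the repo's actual UI:

import gradio as gr

# Minimal sketch of the wiring the two hunks above rely on: whatever variable
# names the slider must also appear in the Button.click() inputs list.
def start(keep_x_past_datasets, gpu_count):
    return f"training with keep={int(keep_x_past_datasets)}, gpus={int(gpu_count)}"

with gr.Blocks() as demo:
    training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
    training_gpu_count = gr.Number(label="GPUs", value=1)
    start_training_button = gr.Button(value="Train")
    training_output = gr.Textbox()
    start_training_button.click(
        start,
        inputs=[training_keep_x_past_datasets, training_gpu_count],
        outputs=[training_output],
    )
# demo.launch() would serve the UI.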