forked from mrq/ai-voice-cloning
set validation to save rate and validation file if exists (need to test later)
This commit is contained in:
parent
fe8bf7a9d1
commit
e862169e7f
|
@ -122,7 +122,7 @@ train:
|
|||
niter: ${iterations}
|
||||
warmup_iter: -1
|
||||
mega_batch_factor: ${gradient_accumulation_size}
|
||||
val_freq: ${iterations}
|
||||
val_freq: ${validation_rate}
|
||||
|
||||
ema_enabled: false # I really don't think EMA matters
|
||||
|
||||
|
|
|
@ -1228,10 +1228,7 @@ def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
|
|||
|
||||
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
validation_name = f"{voice}-val"
|
||||
validation_path = f"./training/{voice}/train.txt"
|
||||
|
||||
with open(dataset_path, 'r', encoding="utf-8") as f:
|
||||
lines = len(f.readlines())
|
||||
|
@ -1303,7 +1300,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
|
|||
messages
|
||||
)
|
||||
|
||||
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
|
||||
def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weight=None, learning_rate_schedule=None, batch_size=None, gradient_accumulation_size=None, print_rate=None, save_rate=None, validation_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None, resume_path=None, half_p=None, bnb=None, workers=None, source_model=None ):
|
||||
if not source_model:
|
||||
source_model = f"./models/tortoise/autoregressive{'_half' if half_p else ''}.pth"
|
||||
|
||||
|
@ -1320,6 +1317,7 @@ def save_training_settings( iterations=None, learning_rate=None, text_ce_lr_weig
|
|||
"dataset_path": dataset_path if dataset_path else "./training/finetune/train.txt",
|
||||
"validation_name": validation_name if validation_name else "finetune",
|
||||
"validation_path": validation_path if validation_path else "./training/finetune/train.txt",
|
||||
'validation_rate': validation_rate if validation_rate else iterations,
|
||||
|
||||
"text_ce_lr_weight": text_ce_lr_weight if text_ce_lr_weight else 0.01,
|
||||
|
||||
|
|
17
src/webui.py
17
src/webui.py
|
@ -300,7 +300,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
|||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
validation_name = f"{voice}-val"
|
||||
validation_path = f"./training/{voice}/train.txt"
|
||||
validation_path = f"./training/{voice}/validation.txt"
|
||||
|
||||
with open(dataset_path, 'r', encoding="utf-8") as f:
|
||||
lines = len(f.readlines())
|
||||
|
@ -312,6 +312,16 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
|||
|
||||
print_rate = int(print_rate * iterations / epochs)
|
||||
save_rate = int(save_rate * iterations / epochs)
|
||||
validation_rate = save_rate
|
||||
|
||||
if iterations % save_rate != 0:
|
||||
adjustment = int(iterations / save_rate) * save_rate
|
||||
messages.append(f"Iteration rate is not evenly divisible by save rate, adjusting: {iterations} => {adjustment}")
|
||||
iterations = adjustment
|
||||
|
||||
if not os.path.exists(validation_path):
|
||||
validation_rate = iterations
|
||||
validation_path = dataset_path
|
||||
|
||||
if not learning_rate_schedule:
|
||||
learning_rate_schedule = EPOCH_SCHEDULE
|
||||
|
@ -329,6 +339,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
|||
gradient_accumulation_size=gradient_accumulation_size,
|
||||
print_rate=print_rate,
|
||||
save_rate=save_rate,
|
||||
validation_rate=validation_rate,
|
||||
name=name,
|
||||
dataset_name=dataset_name,
|
||||
dataset_path=dataset_path,
|
||||
|
@ -559,7 +570,7 @@ def setup_gradio():
|
|||
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
|
||||
|
||||
with gr.Row():
|
||||
training_keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
|
||||
training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
|
||||
with gr.Row():
|
||||
start_training_button = gr.Button(value="Train")
|
||||
|
@ -777,7 +788,7 @@ def setup_gradio():
|
|||
training_configs,
|
||||
verbose_training,
|
||||
training_gpu_count,
|
||||
training_keep_x_past_checkpoints,
|
||||
training_keep_x_past_datasets,
|
||||
],
|
||||
outputs=[
|
||||
training_output,
|
||||
|
|
Loading…
Reference in New Issue
Block a user