diff --git a/models/.template.yaml b/models/.template.yaml index af555e2..28fbd2c 100755 --- a/models/.template.yaml +++ b/models/.template.yaml @@ -28,8 +28,8 @@ datasets: load_aligned_codes: False val: # I really do not care about validation right now name: ${validation_name} - n_workers: 1 - batch_size: 1 + n_workers: ${workers} + batch_size: ${batch_size} mode: paired_voice_audio path: ${validation_path} fetcher_mode: ['lj'] @@ -131,13 +131,8 @@ train: lr_gamma: 0.5 eval: + pure: True output_state: gen - injectors: - gen_inj_eval: - type: generator - generator: generator - in: hq - out: [gen, codebook_commitment_loss] logger: print_freq: ${print_rate} diff --git a/src/utils.py b/src/utils.py index aa10143..9752ca0 100755 --- a/src/utils.py +++ b/src/utils.py @@ -684,7 +684,7 @@ class TrainingState(): except Exception as e: use_tensorboard = False - keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total'] + keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce'] infos = {} highest_step = self.last_info_check_at @@ -1220,6 +1220,44 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres return f"Processed dataset to: {outdir}\n{joined}" +def prepare_validation_dataset( voice, text_length ): + indir = f'./training/{voice}/' + infile = f'{indir}/dataset.txt' + if not os.path.exists(infile): + infile = f'{indir}/train.txt' + with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src: + with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst: + dst.write(src.read()) + + if not os.path.exists(infile): + raise Exception(f"Missing dataset: {infile}") + + with open(infile, 'r', encoding="utf-8") as f: + lines = f.readlines() + + validation = [] + training = [] + + for line in lines: + split = line.split("|") + filename = split[0] + text = split[1] + + if len(text) < text_length: + validation.append(line.strip()) + else: + training.append(line.strip()) + + with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f: + f.write("\n".join(training)) + + with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: + f.write("\n".join(validation)) + + msg = f"Culled {len(validation)} lines" + print(msg) + return msg + def calc_iterations( epochs, lines, batch_size ): iterations = int(epochs * lines / float(batch_size)) return iterations @@ -1227,7 +1265,7 @@ def calc_iterations( epochs, lines, batch_size ): def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ): return [int(iterations * d) for d in schedule] -def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ): +def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ): name = f"{voice}-finetune" dataset_path = f"./training/{voice}/train.txt" @@ -1271,6 +1309,10 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni save_rate = epochs messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}") + if epochs < validation_rate: + validation_rate = epochs + messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {validation_rate}") + if resume_path and not os.path.exists(resume_path): resume_path = None messages.append("Resume path specified, but does not exist. Disabling...") @@ -1297,6 +1339,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni gradient_accumulation_size, print_rate, save_rate, + validation_rate, resume_path, messages ) diff --git a/src/webui.py b/src/webui.py index 778ff0e..9e85230 100755 --- a/src/webui.py +++ b/src/webui.py @@ -205,7 +205,8 @@ def optimize_training_settings_proxy( *args, **kwargs ): gr.update(value=tup[5]), gr.update(value=tup[6]), gr.update(value=tup[7]), - "\n".join(tup[8]) + gr.update(value=tup[8]), + "\n".join(tup[9]) ) def import_training_settings_proxy( voice ): @@ -247,11 +248,15 @@ def import_training_settings_proxy( voice ): print_rate = int(config['logger']['print_freq'] / steps_per_iteration) save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration) + validation_rate = int(config['train']['val_freq'] / steps_per_iteration) - statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES + half_p = config['fp16'] + bnb = True + + statedir = f'{outdir}/training_state/' resumes = [] resume_path = None - source_model = None + source_model = get_halfp_model_path() if half_p else get_model_path('autoregressive.pth') if "pretrain_model_gpt" in config['path']: source_model = config['path']['pretrain_model_gpt'] @@ -267,8 +272,6 @@ def import_training_settings_proxy( voice ): messages.append(f"Latest resume found: {resume_path}") - half_p = config['fp16'] - bnb = True if "ext" in config and "bitsandbytes" in config["ext"]: bnb = config["ext"]["bitsandbytes"] @@ -286,6 +289,7 @@ def import_training_settings_proxy( voice ): gradient_accumulation_size, print_rate, save_rate, + validation_rate, resume_path, half_p, bnb, @@ -295,7 +299,7 @@ def import_training_settings_proxy( voice ): ) -def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ): +def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ): name = f"{voice}-finetune" dataset_name = f"{voice}-train" dataset_path = f"./training/{voice}/train.txt" @@ -312,7 +316,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear print_rate = int(print_rate * iterations / epochs) save_rate = int(save_rate * iterations / epochs) - validation_rate = save_rate + validation_rate = int(validation_rate * iterations / epochs) if iterations % save_rate != 0: adjustment = int(iterations / save_rate) * save_rate @@ -497,7 +501,9 @@ def setup_gradio(): gr.Textbox(label="Language", value="en"), gr.Checkbox(label="Skip Already Transcribed", value=False) ] - prepare_dataset_button = gr.Button(value="Prepare") + transcribe_button = gr.Button(value="Transcribe") + validation_text_cull_size = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0) + prepare_validation_button = gr.Button(value="Prepare Validation") with gr.Column(): prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) with gr.Tab("Generate Configuration"): @@ -524,6 +530,7 @@ def setup_gradio(): training_settings = training_settings + [ gr.Number(label="Print Frequency (in epochs)", value=5, precision=0), gr.Number(label="Save Frequency (in epochs)", value=5, precision=0), + gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0), ] training_settings = training_settings + [ gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"), @@ -823,11 +830,19 @@ def setup_gradio(): ], outputs=training_output #console_output ) - prepare_dataset_button.click( + transcribe_button.click( prepare_dataset_proxy, inputs=dataset_settings, outputs=prepare_dataset_output #console_output ) + prepare_validation_button.click( + prepare_validation_dataset, + inputs=[ + dataset_settings[0], + validation_text_cull_size, + ], + outputs=prepare_dataset_output #console_output + ) refresh_dataset_list.click( lambda: gr.update(choices=get_dataset_list()), inputs=None, @@ -835,11 +850,11 @@ def setup_gradio(): ) optimize_yaml_button.click(optimize_training_settings_proxy, inputs=training_settings, - outputs=training_settings[1:9] + [save_yaml_output] #console_output + outputs=training_settings[1:10] + [save_yaml_output] #console_output ) import_dataset_button.click(import_training_settings_proxy, inputs=dataset_list_dropdown, - outputs=training_settings[:13] + [save_yaml_output] #console_output + outputs=training_settings[:14] + [save_yaml_output] #console_output ) save_yaml_button.click(save_training_settings_proxy, inputs=training_settings,