made validation working (will document later)

This commit is contained in:
mrq 2023-03-08 02:58:00 +00:00
parent a7e0dc9127
commit b4098dca73
3 changed files with 74 additions and 21 deletions

View File

@ -28,8 +28,8 @@ datasets:
load_aligned_codes: False load_aligned_codes: False
val: # I really do not care about validation right now val: # I really do not care about validation right now
name: ${validation_name} name: ${validation_name}
n_workers: 1 n_workers: ${workers}
batch_size: 1 batch_size: ${batch_size}
mode: paired_voice_audio mode: paired_voice_audio
path: ${validation_path} path: ${validation_path}
fetcher_mode: ['lj'] fetcher_mode: ['lj']
@ -131,13 +131,8 @@ train:
lr_gamma: 0.5 lr_gamma: 0.5
eval: eval:
pure: True
output_state: gen output_state: gen
injectors:
gen_inj_eval:
type: generator
generator: generator
in: hq
out: [gen, codebook_commitment_loss]
logger: logger:
print_freq: ${print_rate} print_freq: ${print_rate}

View File

@ -684,7 +684,7 @@ class TrainingState():
except Exception as e: except Exception as e:
use_tensorboard = False use_tensorboard = False
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total'] keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce']
infos = {} infos = {}
highest_step = self.last_info_check_at highest_step = self.last_info_check_at
@ -1220,6 +1220,44 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
return f"Processed dataset to: {outdir}\n{joined}" return f"Processed dataset to: {outdir}\n{joined}"
def prepare_validation_dataset( voice, text_length ):
indir = f'./training/{voice}/'
infile = f'{indir}/dataset.txt'
if not os.path.exists(infile):
infile = f'{indir}/train.txt'
with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
dst.write(src.read())
if not os.path.exists(infile):
raise Exception(f"Missing dataset: {infile}")
with open(infile, 'r', encoding="utf-8") as f:
lines = f.readlines()
validation = []
training = []
for line in lines:
split = line.split("|")
filename = split[0]
text = split[1]
if len(text) < text_length:
validation.append(line.strip())
else:
training.append(line.strip())
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(training))
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
f.write("\n".join(validation))
msg = f"Culled {len(validation)} lines"
print(msg)
return msg
def calc_iterations( epochs, lines, batch_size ): def calc_iterations( epochs, lines, batch_size ):
iterations = int(epochs * lines / float(batch_size)) iterations = int(epochs * lines / float(batch_size))
return iterations return iterations
@ -1227,7 +1265,7 @@ def calc_iterations( epochs, lines, batch_size ):
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ): def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
return [int(iterations * d) for d in schedule] return [int(iterations * d) for d in schedule]
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ): def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ):
name = f"{voice}-finetune" name = f"{voice}-finetune"
dataset_path = f"./training/{voice}/train.txt" dataset_path = f"./training/{voice}/train.txt"
@ -1271,6 +1309,10 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
save_rate = epochs save_rate = epochs
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}") messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
if epochs < validation_rate:
validation_rate = epochs
messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {validation_rate}")
if resume_path and not os.path.exists(resume_path): if resume_path and not os.path.exists(resume_path):
resume_path = None resume_path = None
messages.append("Resume path specified, but does not exist. Disabling...") messages.append("Resume path specified, but does not exist. Disabling...")
@ -1297,6 +1339,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
gradient_accumulation_size, gradient_accumulation_size,
print_rate, print_rate,
save_rate, save_rate,
validation_rate,
resume_path, resume_path,
messages messages
) )

View File

@ -205,7 +205,8 @@ def optimize_training_settings_proxy( *args, **kwargs ):
gr.update(value=tup[5]), gr.update(value=tup[5]),
gr.update(value=tup[6]), gr.update(value=tup[6]),
gr.update(value=tup[7]), gr.update(value=tup[7]),
"\n".join(tup[8]) gr.update(value=tup[8]),
"\n".join(tup[9])
) )
def import_training_settings_proxy( voice ): def import_training_settings_proxy( voice ):
@ -247,11 +248,15 @@ def import_training_settings_proxy( voice ):
print_rate = int(config['logger']['print_freq'] / steps_per_iteration) print_rate = int(config['logger']['print_freq'] / steps_per_iteration)
save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration) save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration)
validation_rate = int(config['train']['val_freq'] / steps_per_iteration)
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES half_p = config['fp16']
bnb = True
statedir = f'{outdir}/training_state/'
resumes = [] resumes = []
resume_path = None resume_path = None
source_model = None source_model = get_halfp_model_path() if half_p else get_model_path('autoregressive.pth')
if "pretrain_model_gpt" in config['path']: if "pretrain_model_gpt" in config['path']:
source_model = config['path']['pretrain_model_gpt'] source_model = config['path']['pretrain_model_gpt']
@ -267,8 +272,6 @@ def import_training_settings_proxy( voice ):
messages.append(f"Latest resume found: {resume_path}") messages.append(f"Latest resume found: {resume_path}")
half_p = config['fp16']
bnb = True
if "ext" in config and "bitsandbytes" in config["ext"]: if "ext" in config and "bitsandbytes" in config["ext"]:
bnb = config["ext"]["bitsandbytes"] bnb = config["ext"]["bitsandbytes"]
@ -286,6 +289,7 @@ def import_training_settings_proxy( voice ):
gradient_accumulation_size, gradient_accumulation_size,
print_rate, print_rate,
save_rate, save_rate,
validation_rate,
resume_path, resume_path,
half_p, half_p,
bnb, bnb,
@ -295,7 +299,7 @@ def import_training_settings_proxy( voice ):
) )
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ): def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ):
name = f"{voice}-finetune" name = f"{voice}-finetune"
dataset_name = f"{voice}-train" dataset_name = f"{voice}-train"
dataset_path = f"./training/{voice}/train.txt" dataset_path = f"./training/{voice}/train.txt"
@ -312,7 +316,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
print_rate = int(print_rate * iterations / epochs) print_rate = int(print_rate * iterations / epochs)
save_rate = int(save_rate * iterations / epochs) save_rate = int(save_rate * iterations / epochs)
validation_rate = save_rate validation_rate = int(validation_rate * iterations / epochs)
if iterations % save_rate != 0: if iterations % save_rate != 0:
adjustment = int(iterations / save_rate) * save_rate adjustment = int(iterations / save_rate) * save_rate
@ -497,7 +501,9 @@ def setup_gradio():
gr.Textbox(label="Language", value="en"), gr.Textbox(label="Language", value="en"),
gr.Checkbox(label="Skip Already Transcribed", value=False) gr.Checkbox(label="Skip Already Transcribed", value=False)
] ]
prepare_dataset_button = gr.Button(value="Prepare") transcribe_button = gr.Button(value="Transcribe")
validation_text_cull_size = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0)
prepare_validation_button = gr.Button(value="Prepare Validation")
with gr.Column(): with gr.Column():
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8) prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
with gr.Tab("Generate Configuration"): with gr.Tab("Generate Configuration"):
@ -524,6 +530,7 @@ def setup_gradio():
training_settings = training_settings + [ training_settings = training_settings + [
gr.Number(label="Print Frequency (in epochs)", value=5, precision=0), gr.Number(label="Print Frequency (in epochs)", value=5, precision=0),
gr.Number(label="Save Frequency (in epochs)", value=5, precision=0), gr.Number(label="Save Frequency (in epochs)", value=5, precision=0),
gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0),
] ]
training_settings = training_settings + [ training_settings = training_settings + [
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"), gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
@ -823,11 +830,19 @@ def setup_gradio():
], ],
outputs=training_output #console_output outputs=training_output #console_output
) )
prepare_dataset_button.click( transcribe_button.click(
prepare_dataset_proxy, prepare_dataset_proxy,
inputs=dataset_settings, inputs=dataset_settings,
outputs=prepare_dataset_output #console_output outputs=prepare_dataset_output #console_output
) )
prepare_validation_button.click(
prepare_validation_dataset,
inputs=[
dataset_settings[0],
validation_text_cull_size,
],
outputs=prepare_dataset_output #console_output
)
refresh_dataset_list.click( refresh_dataset_list.click(
lambda: gr.update(choices=get_dataset_list()), lambda: gr.update(choices=get_dataset_list()),
inputs=None, inputs=None,
@ -835,11 +850,11 @@ def setup_gradio():
) )
optimize_yaml_button.click(optimize_training_settings_proxy, optimize_yaml_button.click(optimize_training_settings_proxy,
inputs=training_settings, inputs=training_settings,
outputs=training_settings[1:9] + [save_yaml_output] #console_output outputs=training_settings[1:10] + [save_yaml_output] #console_output
) )
import_dataset_button.click(import_training_settings_proxy, import_dataset_button.click(import_training_settings_proxy,
inputs=dataset_list_dropdown, inputs=dataset_list_dropdown,
outputs=training_settings[:13] + [save_yaml_output] #console_output outputs=training_settings[:14] + [save_yaml_output] #console_output
) )
save_yaml_button.click(save_training_settings_proxy, save_yaml_button.click(save_training_settings_proxy,
inputs=training_settings, inputs=training_settings,