made validation working (will document later)
This commit is contained in:
parent
a7e0dc9127
commit
b4098dca73
|
@ -28,8 +28,8 @@ datasets:
|
||||||
load_aligned_codes: False
|
load_aligned_codes: False
|
||||||
val: # I really do not care about validation right now
|
val: # I really do not care about validation right now
|
||||||
name: ${validation_name}
|
name: ${validation_name}
|
||||||
n_workers: 1
|
n_workers: ${workers}
|
||||||
batch_size: 1
|
batch_size: ${batch_size}
|
||||||
mode: paired_voice_audio
|
mode: paired_voice_audio
|
||||||
path: ${validation_path}
|
path: ${validation_path}
|
||||||
fetcher_mode: ['lj']
|
fetcher_mode: ['lj']
|
||||||
|
@ -131,13 +131,8 @@ train:
|
||||||
lr_gamma: 0.5
|
lr_gamma: 0.5
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
|
pure: True
|
||||||
output_state: gen
|
output_state: gen
|
||||||
injectors:
|
|
||||||
gen_inj_eval:
|
|
||||||
type: generator
|
|
||||||
generator: generator
|
|
||||||
in: hq
|
|
||||||
out: [gen, codebook_commitment_loss]
|
|
||||||
|
|
||||||
logger:
|
logger:
|
||||||
print_freq: ${print_rate}
|
print_freq: ${print_rate}
|
||||||
|
|
47
src/utils.py
47
src/utils.py
|
@ -684,7 +684,7 @@ class TrainingState():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
use_tensorboard = False
|
use_tensorboard = False
|
||||||
|
|
||||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
|
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce']
|
||||||
infos = {}
|
infos = {}
|
||||||
highest_step = self.last_info_check_at
|
highest_step = self.last_info_check_at
|
||||||
|
|
||||||
|
@ -1220,6 +1220,44 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
||||||
|
|
||||||
return f"Processed dataset to: {outdir}\n{joined}"
|
return f"Processed dataset to: {outdir}\n{joined}"
|
||||||
|
|
||||||
|
def prepare_validation_dataset( voice, text_length ):
|
||||||
|
indir = f'./training/{voice}/'
|
||||||
|
infile = f'{indir}/dataset.txt'
|
||||||
|
if not os.path.exists(infile):
|
||||||
|
infile = f'{indir}/train.txt'
|
||||||
|
with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
|
||||||
|
with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
|
||||||
|
dst.write(src.read())
|
||||||
|
|
||||||
|
if not os.path.exists(infile):
|
||||||
|
raise Exception(f"Missing dataset: {infile}")
|
||||||
|
|
||||||
|
with open(infile, 'r', encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
validation = []
|
||||||
|
training = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
split = line.split("|")
|
||||||
|
filename = split[0]
|
||||||
|
text = split[1]
|
||||||
|
|
||||||
|
if len(text) < text_length:
|
||||||
|
validation.append(line.strip())
|
||||||
|
else:
|
||||||
|
training.append(line.strip())
|
||||||
|
|
||||||
|
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
|
||||||
|
f.write("\n".join(training))
|
||||||
|
|
||||||
|
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||||
|
f.write("\n".join(validation))
|
||||||
|
|
||||||
|
msg = f"Culled {len(validation)} lines"
|
||||||
|
print(msg)
|
||||||
|
return msg
|
||||||
|
|
||||||
def calc_iterations( epochs, lines, batch_size ):
|
def calc_iterations( epochs, lines, batch_size ):
|
||||||
iterations = int(epochs * lines / float(batch_size))
|
iterations = int(epochs * lines / float(batch_size))
|
||||||
return iterations
|
return iterations
|
||||||
|
@ -1227,7 +1265,7 @@ def calc_iterations( epochs, lines, batch_size ):
|
||||||
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
|
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
|
||||||
return [int(iterations * d) for d in schedule]
|
return [int(iterations * d) for d in schedule]
|
||||||
|
|
||||||
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
|
||||||
|
@ -1271,6 +1309,10 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
|
||||||
save_rate = epochs
|
save_rate = epochs
|
||||||
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
||||||
|
|
||||||
|
if epochs < validation_rate:
|
||||||
|
validation_rate = epochs
|
||||||
|
messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {validation_rate}")
|
||||||
|
|
||||||
if resume_path and not os.path.exists(resume_path):
|
if resume_path and not os.path.exists(resume_path):
|
||||||
resume_path = None
|
resume_path = None
|
||||||
messages.append("Resume path specified, but does not exist. Disabling...")
|
messages.append("Resume path specified, but does not exist. Disabling...")
|
||||||
|
@ -1297,6 +1339,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
|
||||||
gradient_accumulation_size,
|
gradient_accumulation_size,
|
||||||
print_rate,
|
print_rate,
|
||||||
save_rate,
|
save_rate,
|
||||||
|
validation_rate,
|
||||||
resume_path,
|
resume_path,
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
|
|
37
src/webui.py
37
src/webui.py
|
@ -205,7 +205,8 @@ def optimize_training_settings_proxy( *args, **kwargs ):
|
||||||
gr.update(value=tup[5]),
|
gr.update(value=tup[5]),
|
||||||
gr.update(value=tup[6]),
|
gr.update(value=tup[6]),
|
||||||
gr.update(value=tup[7]),
|
gr.update(value=tup[7]),
|
||||||
"\n".join(tup[8])
|
gr.update(value=tup[8]),
|
||||||
|
"\n".join(tup[9])
|
||||||
)
|
)
|
||||||
|
|
||||||
def import_training_settings_proxy( voice ):
|
def import_training_settings_proxy( voice ):
|
||||||
|
@ -247,11 +248,15 @@ def import_training_settings_proxy( voice ):
|
||||||
|
|
||||||
print_rate = int(config['logger']['print_freq'] / steps_per_iteration)
|
print_rate = int(config['logger']['print_freq'] / steps_per_iteration)
|
||||||
save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration)
|
save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration)
|
||||||
|
validation_rate = int(config['train']['val_freq'] / steps_per_iteration)
|
||||||
|
|
||||||
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
|
half_p = config['fp16']
|
||||||
|
bnb = True
|
||||||
|
|
||||||
|
statedir = f'{outdir}/training_state/'
|
||||||
resumes = []
|
resumes = []
|
||||||
resume_path = None
|
resume_path = None
|
||||||
source_model = None
|
source_model = get_halfp_model_path() if half_p else get_model_path('autoregressive.pth')
|
||||||
|
|
||||||
if "pretrain_model_gpt" in config['path']:
|
if "pretrain_model_gpt" in config['path']:
|
||||||
source_model = config['path']['pretrain_model_gpt']
|
source_model = config['path']['pretrain_model_gpt']
|
||||||
|
@ -267,8 +272,6 @@ def import_training_settings_proxy( voice ):
|
||||||
messages.append(f"Latest resume found: {resume_path}")
|
messages.append(f"Latest resume found: {resume_path}")
|
||||||
|
|
||||||
|
|
||||||
half_p = config['fp16']
|
|
||||||
bnb = True
|
|
||||||
|
|
||||||
if "ext" in config and "bitsandbytes" in config["ext"]:
|
if "ext" in config and "bitsandbytes" in config["ext"]:
|
||||||
bnb = config["ext"]["bitsandbytes"]
|
bnb = config["ext"]["bitsandbytes"]
|
||||||
|
@ -286,6 +289,7 @@ def import_training_settings_proxy( voice ):
|
||||||
gradient_accumulation_size,
|
gradient_accumulation_size,
|
||||||
print_rate,
|
print_rate,
|
||||||
save_rate,
|
save_rate,
|
||||||
|
validation_rate,
|
||||||
resume_path,
|
resume_path,
|
||||||
half_p,
|
half_p,
|
||||||
bnb,
|
bnb,
|
||||||
|
@ -295,7 +299,7 @@ def import_training_settings_proxy( voice ):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||||
name = f"{voice}-finetune"
|
name = f"{voice}-finetune"
|
||||||
dataset_name = f"{voice}-train"
|
dataset_name = f"{voice}-train"
|
||||||
dataset_path = f"./training/{voice}/train.txt"
|
dataset_path = f"./training/{voice}/train.txt"
|
||||||
|
@ -312,7 +316,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
||||||
|
|
||||||
print_rate = int(print_rate * iterations / epochs)
|
print_rate = int(print_rate * iterations / epochs)
|
||||||
save_rate = int(save_rate * iterations / epochs)
|
save_rate = int(save_rate * iterations / epochs)
|
||||||
validation_rate = save_rate
|
validation_rate = int(validation_rate * iterations / epochs)
|
||||||
|
|
||||||
if iterations % save_rate != 0:
|
if iterations % save_rate != 0:
|
||||||
adjustment = int(iterations / save_rate) * save_rate
|
adjustment = int(iterations / save_rate) * save_rate
|
||||||
|
@ -497,7 +501,9 @@ def setup_gradio():
|
||||||
gr.Textbox(label="Language", value="en"),
|
gr.Textbox(label="Language", value="en"),
|
||||||
gr.Checkbox(label="Skip Already Transcribed", value=False)
|
gr.Checkbox(label="Skip Already Transcribed", value=False)
|
||||||
]
|
]
|
||||||
prepare_dataset_button = gr.Button(value="Prepare")
|
transcribe_button = gr.Button(value="Transcribe")
|
||||||
|
validation_text_cull_size = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0)
|
||||||
|
prepare_validation_button = gr.Button(value="Prepare Validation")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||||
with gr.Tab("Generate Configuration"):
|
with gr.Tab("Generate Configuration"):
|
||||||
|
@ -524,6 +530,7 @@ def setup_gradio():
|
||||||
training_settings = training_settings + [
|
training_settings = training_settings + [
|
||||||
gr.Number(label="Print Frequency (in epochs)", value=5, precision=0),
|
gr.Number(label="Print Frequency (in epochs)", value=5, precision=0),
|
||||||
gr.Number(label="Save Frequency (in epochs)", value=5, precision=0),
|
gr.Number(label="Save Frequency (in epochs)", value=5, precision=0),
|
||||||
|
gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0),
|
||||||
]
|
]
|
||||||
training_settings = training_settings + [
|
training_settings = training_settings + [
|
||||||
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
||||||
|
@ -823,11 +830,19 @@ def setup_gradio():
|
||||||
],
|
],
|
||||||
outputs=training_output #console_output
|
outputs=training_output #console_output
|
||||||
)
|
)
|
||||||
prepare_dataset_button.click(
|
transcribe_button.click(
|
||||||
prepare_dataset_proxy,
|
prepare_dataset_proxy,
|
||||||
inputs=dataset_settings,
|
inputs=dataset_settings,
|
||||||
outputs=prepare_dataset_output #console_output
|
outputs=prepare_dataset_output #console_output
|
||||||
)
|
)
|
||||||
|
prepare_validation_button.click(
|
||||||
|
prepare_validation_dataset,
|
||||||
|
inputs=[
|
||||||
|
dataset_settings[0],
|
||||||
|
validation_text_cull_size,
|
||||||
|
],
|
||||||
|
outputs=prepare_dataset_output #console_output
|
||||||
|
)
|
||||||
refresh_dataset_list.click(
|
refresh_dataset_list.click(
|
||||||
lambda: gr.update(choices=get_dataset_list()),
|
lambda: gr.update(choices=get_dataset_list()),
|
||||||
inputs=None,
|
inputs=None,
|
||||||
|
@ -835,11 +850,11 @@ def setup_gradio():
|
||||||
)
|
)
|
||||||
optimize_yaml_button.click(optimize_training_settings_proxy,
|
optimize_yaml_button.click(optimize_training_settings_proxy,
|
||||||
inputs=training_settings,
|
inputs=training_settings,
|
||||||
outputs=training_settings[1:9] + [save_yaml_output] #console_output
|
outputs=training_settings[1:10] + [save_yaml_output] #console_output
|
||||||
)
|
)
|
||||||
import_dataset_button.click(import_training_settings_proxy,
|
import_dataset_button.click(import_training_settings_proxy,
|
||||||
inputs=dataset_list_dropdown,
|
inputs=dataset_list_dropdown,
|
||||||
outputs=training_settings[:13] + [save_yaml_output] #console_output
|
outputs=training_settings[:14] + [save_yaml_output] #console_output
|
||||||
)
|
)
|
||||||
save_yaml_button.click(save_training_settings_proxy,
|
save_yaml_button.click(save_training_settings_proxy,
|
||||||
inputs=training_settings,
|
inputs=training_settings,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user