forked from mrq/ai-voice-cloning
made validation working (will document later)
This commit is contained in:
parent
a7e0dc9127
commit
b4098dca73
|
@ -28,8 +28,8 @@ datasets:
|
|||
load_aligned_codes: False
|
||||
val: # I really do not care about validation right now
|
||||
name: ${validation_name}
|
||||
n_workers: 1
|
||||
batch_size: 1
|
||||
n_workers: ${workers}
|
||||
batch_size: ${batch_size}
|
||||
mode: paired_voice_audio
|
||||
path: ${validation_path}
|
||||
fetcher_mode: ['lj']
|
||||
|
@ -131,13 +131,8 @@ train:
|
|||
lr_gamma: 0.5
|
||||
|
||||
eval:
|
||||
pure: True
|
||||
output_state: gen
|
||||
injectors:
|
||||
gen_inj_eval:
|
||||
type: generator
|
||||
generator: generator
|
||||
in: hq
|
||||
out: [gen, codebook_commitment_loss]
|
||||
|
||||
logger:
|
||||
print_freq: ${print_rate}
|
||||
|
|
47
src/utils.py
47
src/utils.py
|
@ -684,7 +684,7 @@ class TrainingState():
|
|||
except Exception as e:
|
||||
use_tensorboard = False
|
||||
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce']
|
||||
infos = {}
|
||||
highest_step = self.last_info_check_at
|
||||
|
||||
|
@ -1220,6 +1220,44 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
|||
|
||||
return f"Processed dataset to: {outdir}\n{joined}"
|
||||
|
||||
def prepare_validation_dataset( voice, text_length ):
|
||||
indir = f'./training/{voice}/'
|
||||
infile = f'{indir}/dataset.txt'
|
||||
if not os.path.exists(infile):
|
||||
infile = f'{indir}/train.txt'
|
||||
with open(f'{indir}/train.txt', 'r', encoding="utf-8") as src:
|
||||
with open(f'{indir}/dataset.txt', 'w', encoding="utf-8") as dst:
|
||||
dst.write(src.read())
|
||||
|
||||
if not os.path.exists(infile):
|
||||
raise Exception(f"Missing dataset: {infile}")
|
||||
|
||||
with open(infile, 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
validation = []
|
||||
training = []
|
||||
|
||||
for line in lines:
|
||||
split = line.split("|")
|
||||
filename = split[0]
|
||||
text = split[1]
|
||||
|
||||
if len(text) < text_length:
|
||||
validation.append(line.strip())
|
||||
else:
|
||||
training.append(line.strip())
|
||||
|
||||
with open(f'{indir}/train.txt', 'w', encoding="utf-8") as f:
|
||||
f.write("\n".join(training))
|
||||
|
||||
with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
|
||||
f.write("\n".join(validation))
|
||||
|
||||
msg = f"Culled {len(validation)} lines"
|
||||
print(msg)
|
||||
return msg
|
||||
|
||||
def calc_iterations( epochs, lines, batch_size ):
|
||||
iterations = int(epochs * lines / float(batch_size))
|
||||
return iterations
|
||||
|
@ -1227,7 +1265,7 @@ def calc_iterations( epochs, lines, batch_size ):
|
|||
def schedule_learning_rate( iterations, schedule=EPOCH_SCHEDULE ):
|
||||
return [int(iterations * d) for d in schedule]
|
||||
|
||||
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||
def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
|
||||
|
@ -1271,6 +1309,10 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
|
|||
save_rate = epochs
|
||||
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {save_rate}")
|
||||
|
||||
if epochs < validation_rate:
|
||||
validation_rate = epochs
|
||||
messages.append(f"Validation rate is too small for the given iteration step, clamping validation rate to: {validation_rate}")
|
||||
|
||||
if resume_path and not os.path.exists(resume_path):
|
||||
resume_path = None
|
||||
messages.append("Resume path specified, but does not exist. Disabling...")
|
||||
|
@ -1297,6 +1339,7 @@ def optimize_training_settings( epochs, learning_rate, text_ce_lr_weight, learni
|
|||
gradient_accumulation_size,
|
||||
print_rate,
|
||||
save_rate,
|
||||
validation_rate,
|
||||
resume_path,
|
||||
messages
|
||||
)
|
||||
|
|
37
src/webui.py
37
src/webui.py
|
@ -205,7 +205,8 @@ def optimize_training_settings_proxy( *args, **kwargs ):
|
|||
gr.update(value=tup[5]),
|
||||
gr.update(value=tup[6]),
|
||||
gr.update(value=tup[7]),
|
||||
"\n".join(tup[8])
|
||||
gr.update(value=tup[8]),
|
||||
"\n".join(tup[9])
|
||||
)
|
||||
|
||||
def import_training_settings_proxy( voice ):
|
||||
|
@ -247,11 +248,15 @@ def import_training_settings_proxy( voice ):
|
|||
|
||||
print_rate = int(config['logger']['print_freq'] / steps_per_iteration)
|
||||
save_rate = int(config['logger']['save_checkpoint_freq'] / steps_per_iteration)
|
||||
validation_rate = int(config['train']['val_freq'] / steps_per_iteration)
|
||||
|
||||
statedir = f'{outdir}/training_state/' # NOOO STOP MIXING YOUR CASES
|
||||
half_p = config['fp16']
|
||||
bnb = True
|
||||
|
||||
statedir = f'{outdir}/training_state/'
|
||||
resumes = []
|
||||
resume_path = None
|
||||
source_model = None
|
||||
source_model = get_halfp_model_path() if half_p else get_model_path('autoregressive.pth')
|
||||
|
||||
if "pretrain_model_gpt" in config['path']:
|
||||
source_model = config['path']['pretrain_model_gpt']
|
||||
|
@ -267,8 +272,6 @@ def import_training_settings_proxy( voice ):
|
|||
messages.append(f"Latest resume found: {resume_path}")
|
||||
|
||||
|
||||
half_p = config['fp16']
|
||||
bnb = True
|
||||
|
||||
if "ext" in config and "bitsandbytes" in config["ext"]:
|
||||
bnb = config["ext"]["bitsandbytes"]
|
||||
|
@ -286,6 +289,7 @@ def import_training_settings_proxy( voice ):
|
|||
gradient_accumulation_size,
|
||||
print_rate,
|
||||
save_rate,
|
||||
validation_rate,
|
||||
resume_path,
|
||||
half_p,
|
||||
bnb,
|
||||
|
@ -295,7 +299,7 @@ def import_training_settings_proxy( voice ):
|
|||
)
|
||||
|
||||
|
||||
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||
def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, learning_rate_schedule, batch_size, gradient_accumulation_size, print_rate, save_rate, validation_rate, resume_path, half_p, bnb, workers, source_model, voice ):
|
||||
name = f"{voice}-finetune"
|
||||
dataset_name = f"{voice}-train"
|
||||
dataset_path = f"./training/{voice}/train.txt"
|
||||
|
@ -312,7 +316,7 @@ def save_training_settings_proxy( epochs, learning_rate, text_ce_lr_weight, lear
|
|||
|
||||
print_rate = int(print_rate * iterations / epochs)
|
||||
save_rate = int(save_rate * iterations / epochs)
|
||||
validation_rate = save_rate
|
||||
validation_rate = int(validation_rate * iterations / epochs)
|
||||
|
||||
if iterations % save_rate != 0:
|
||||
adjustment = int(iterations / save_rate) * save_rate
|
||||
|
@ -497,7 +501,9 @@ def setup_gradio():
|
|||
gr.Textbox(label="Language", value="en"),
|
||||
gr.Checkbox(label="Skip Already Transcribed", value=False)
|
||||
]
|
||||
prepare_dataset_button = gr.Button(value="Prepare")
|
||||
transcribe_button = gr.Button(value="Transcribe")
|
||||
validation_text_cull_size = gr.Number(label="Validation Text Length Cull Size", value=12, precision=0)
|
||||
prepare_validation_button = gr.Button(value="Prepare Validation")
|
||||
with gr.Column():
|
||||
prepare_dataset_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
|
||||
with gr.Tab("Generate Configuration"):
|
||||
|
@ -524,6 +530,7 @@ def setup_gradio():
|
|||
training_settings = training_settings + [
|
||||
gr.Number(label="Print Frequency (in epochs)", value=5, precision=0),
|
||||
gr.Number(label="Save Frequency (in epochs)", value=5, precision=0),
|
||||
gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0),
|
||||
]
|
||||
training_settings = training_settings + [
|
||||
gr.Textbox(label="Resume State Path", placeholder="./training/${voice}-finetune/training_state/${last_state}.state"),
|
||||
|
@ -823,11 +830,19 @@ def setup_gradio():
|
|||
],
|
||||
outputs=training_output #console_output
|
||||
)
|
||||
prepare_dataset_button.click(
|
||||
transcribe_button.click(
|
||||
prepare_dataset_proxy,
|
||||
inputs=dataset_settings,
|
||||
outputs=prepare_dataset_output #console_output
|
||||
)
|
||||
prepare_validation_button.click(
|
||||
prepare_validation_dataset,
|
||||
inputs=[
|
||||
dataset_settings[0],
|
||||
validation_text_cull_size,
|
||||
],
|
||||
outputs=prepare_dataset_output #console_output
|
||||
)
|
||||
refresh_dataset_list.click(
|
||||
lambda: gr.update(choices=get_dataset_list()),
|
||||
inputs=None,
|
||||
|
@ -835,11 +850,11 @@ def setup_gradio():
|
|||
)
|
||||
optimize_yaml_button.click(optimize_training_settings_proxy,
|
||||
inputs=training_settings,
|
||||
outputs=training_settings[1:9] + [save_yaml_output] #console_output
|
||||
outputs=training_settings[1:10] + [save_yaml_output] #console_output
|
||||
)
|
||||
import_dataset_button.click(import_training_settings_proxy,
|
||||
inputs=dataset_list_dropdown,
|
||||
outputs=training_settings[:13] + [save_yaml_output] #console_output
|
||||
outputs=training_settings[:14] + [save_yaml_output] #console_output
|
||||
)
|
||||
save_yaml_button.click(save_training_settings_proxy,
|
||||
inputs=training_settings,
|
||||
|
|
Loading…
Reference in New Issue
Block a user