Added a helper script that culls sufficiently short lines from the training set into a validation set (if validating during training yields good results, I'll add it to the web UI)

This commit is contained in:
mrq 2023-03-07 20:16:49 +00:00
parent 7f89e8058a
commit fe8bf7a9d1
3 changed files with 46 additions and 11 deletions

35
src/cull_dataset.py Executable file
View File

@ -0,0 +1,35 @@
"""Split a voice's training list into training and validation subsets.

Usage: cull_dataset.py <voice> <cap>

Reads ./training/<voice>/train.txt, where each entry is one line of
pipe-separated fields (`filename|text[|...]`), and writes two files
next to it:

* validation.txt   -- entries whose text is shorter than <cap> characters
* train_culled.txt -- every other entry
"""
import os
import sys


def cull_lines(lines, cap):
    """Partition dataset lines by the length of their text field.

    Args:
        lines: iterable of raw lines in `filename|text[|...]` form; trailing
            newlines and surrounding whitespace are stripped.
        cap: entries whose text is strictly shorter than this many
            characters go to the validation set.

    Returns:
        A `(training, validation)` tuple of lists of stripped lines, each
        preserving the input order.
    """
    training = []
    validation = []
    for raw in lines:
        # Strip first so the line terminator is not counted towards the text
        # length (otherwise the last, unterminated line is measured
        # differently from every other line).
        entry = raw.strip()
        if not entry:
            # Blank lines carry no data; skip instead of crashing on them.
            continue
        fields = entry.split("|")
        # An entry without a text field cannot be measured; keep it in the
        # training set rather than dropping it or raising IndexError.
        if len(fields) > 1 and len(fields[1]) < cap:
            validation.append(entry)
        else:
            training.append(entry)
    return training, validation


def main(argv=None):
    """Run the cull for the voice and cap given on the command line.

    Args:
        argv: argument vector (`[program, voice, cap]`); defaults to
            `sys.argv`.

    Raises:
        SystemExit: if the arguments are missing or the cap is not an integer.
        NotADirectoryError: if ./training/<voice> does not exist.
        FileNotFoundError: if train.txt is missing from that directory.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        program = argv[0] if argv else "cull_dataset.py"
        raise SystemExit(f"Usage: {program} <voice> <cap>")
    try:
        cap = int(argv[2])
    except ValueError:
        raise SystemExit(f"Invalid cap (expected an integer): {argv[2]}") from None

    # No trailing slash here, so the joined paths below contain no "//".
    indir = f'./training/{argv[1]}'
    source = f'{indir}/train.txt'

    if not os.path.isdir(indir):
        raise NotADirectoryError(f"Invalid directory: {indir}")
    if not os.path.exists(source):
        raise FileNotFoundError(f"Missing dataset: {source}")

    with open(source, 'r', encoding="utf-8") as f:
        training, validation = cull_lines(f, cap)

    with open(f'{indir}/train_culled.txt', 'w', encoding="utf-8") as f:
        f.write("\n".join(training))

    with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
        f.write("\n".join(validation))

    print(f"Culled {len(validation)} lines")


if __name__ == "__main__":
    main()

View File

@ -605,7 +605,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
# superfluous, but it cleans up some things # superfluous, but it cleans up some things
class TrainingState(): class TrainingState():
def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1): def __init__(self, config_path, keep_x_past_checkpoints=0, start=True, gpus=1):
# parse config to get its iteration # parse config to get its iteration
with open(config_path, 'r') as file: with open(config_path, 'r') as file:
self.config = yaml.safe_load(file) self.config = yaml.safe_load(file)
@ -664,8 +664,8 @@ class TrainingState():
self.loss_milestones = [ 1.0, 0.15, 0.05 ] self.loss_milestones = [ 1.0, 0.15, 0.05 ]
self.load_losses() self.load_losses()
if keep_x_past_datasets > 0: if keep_x_past_checkpoints > 0:
self.cleanup_old(keep=keep_x_past_datasets) self.cleanup_old(keep=keep_x_past_checkpoints)
if start: if start:
self.spawn_process(config_path=config_path, gpus=gpus) self.spawn_process(config_path=config_path, gpus=gpus)
@ -772,7 +772,7 @@ class TrainingState():
print("Removing", path) print("Removing", path)
os.remove(path) os.remove(path)
def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ): def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
self.buffer.append(f'{line}') self.buffer.append(f'{line}')
should_return = False should_return = False
@ -830,7 +830,7 @@ class TrainingState():
print(f'{"{:.3f}".format(percent*100)}% {message}') print(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
self.cleanup_old(keep=keep_x_past_datasets) self.cleanup_old(keep=keep_x_past_checkpoints)
if line.find('%|') > 0: if line.find('%|') > 0:
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
@ -986,7 +986,7 @@ class TrainingState():
message, message,
) )
def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)): def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
global training_state global training_state
if training_state and training_state.process: if training_state and training_state.process:
return "Training already in progress" return "Training already in progress"
@ -1008,13 +1008,13 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
unload_whisper() unload_whisper()
unload_voicefixer() unload_voicefixer()
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus) training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints, gpus=gpus)
for line in iter(training_state.process.stdout.readline, ""): for line in iter(training_state.process.stdout.readline, ""):
if training_state.killed: if training_state.killed:
return return
result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress ) result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
if result: if result:
yield result yield result
@ -1164,7 +1164,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
for line in parsed_list: for line in parsed_list:
match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0]) match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
print(match)
if match is None or len(match) == 0: if match is None or len(match) == 0:
continue continue

View File

@ -559,7 +559,7 @@ def setup_gradio():
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
with gr.Row(): with gr.Row():
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) training_keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_gpu_count = gr.Number(label="GPUs", value=get_device_count()) training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
with gr.Row(): with gr.Row():
start_training_button = gr.Button(value="Train") start_training_button = gr.Button(value="Train")
@ -777,7 +777,7 @@ def setup_gradio():
training_configs, training_configs,
verbose_training, verbose_training,
training_gpu_count, training_gpu_count,
training_keep_x_past_datasets, training_keep_x_past_checkpoints,
], ],
outputs=[ outputs=[
training_output, training_output,