From fe8bf7a9d1bf800c06d31c2404e759874c574e9f Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 7 Mar 2023 20:16:49 +0000 Subject: [PATCH] added helper script to cull short enough lines from training set as a validation set (if it yields good results doing validation during training, i'll add it to the web ui) --- src/cull_dataset.py | 35 +++++++++++++++++++++++++++++++++++ src/utils.py | 18 +++++++++--------- src/webui.py | 4 ++-- 3 files changed, 46 insertions(+), 11 deletions(-) create mode 100755 src/cull_dataset.py diff --git a/src/cull_dataset.py b/src/cull_dataset.py new file mode 100755 index 0000000..0572405 --- /dev/null +++ b/src/cull_dataset.py @@ -0,0 +1,35 @@ +import os +import sys + +indir = f'./training/{sys.argv[1]}/' +cap = int(sys.argv[2]) + +if not os.path.isdir(indir): + raise Exception(f"Invalid directory: {indir}") + +if not os.path.exists(f'{indir}/train.txt'): + raise Exception(f"Missing dataset: {indir}/train.txt") + +with open(f'{indir}/train.txt', 'r', encoding="utf-8") as f: + lines = f.readlines() + +validation = [] +training = [] + +for line in lines: + split = line.split("|") + filename = split[0] + text = split[1] + + if len(text) < cap: + validation.append(line.strip()) + else: + training.append(line.strip()) + +with open(f'{indir}/train_culled.txt', 'w', encoding="utf-8") as f: + f.write("\n".join(training)) + +with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f: + f.write("\n".join(validation)) + +print(f"Culled {len(validation)} lines") \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index c209b91..1176880 100755 --- a/src/utils.py +++ b/src/utils.py @@ -605,7 +605,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog # superfluous, but it cleans up some things class TrainingState(): - def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1): + def __init__(self, config_path, keep_x_past_checkpoints=0, start=True, gpus=1): # parse config to get its iteration with open(config_path, 'r') as file: self.config = yaml.safe_load(file) @@ -664,8 +664,8 @@ class TrainingState(): self.loss_milestones = [ 1.0, 0.15, 0.05 ] self.load_losses() - if keep_x_past_datasets > 0: - self.cleanup_old(keep=keep_x_past_datasets) + if keep_x_past_checkpoints > 0: + self.cleanup_old(keep=keep_x_past_checkpoints) if start: self.spawn_process(config_path=config_path, gpus=gpus) @@ -772,7 +772,7 @@ class TrainingState(): print("Removing", path) os.remove(path) - def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ): + def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ): self.buffer.append(f'{line}') should_return = False @@ -830,7 +830,7 @@ class TrainingState(): print(f'{"{:.3f}".format(percent*100)}% {message}') self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') - self.cleanup_old(keep=keep_x_past_datasets) + self.cleanup_old(keep=keep_x_past_checkpoints) if line.find('%|') > 0: match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) @@ -986,7 +986,7 @@ class TrainingState(): message, ) -def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)): +def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)): global training_state if training_state and training_state.process: return "Training already in progress" @@ -1008,13 +1008,13 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro unload_whisper() unload_voicefixer() - training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus) + training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints, gpus=gpus) for line in iter(training_state.process.stdout.readline, ""): if training_state.killed: return - result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress ) + result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress ) print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}") if result: yield result @@ -1164,7 +1164,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres for line in parsed_list: match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0]) - print(match) + if match is None or len(match) == 0: continue diff --git a/src/webui.py b/src/webui.py index abc73b3..4cedadb 100755 --- a/src/webui.py +++ b/src/webui.py @@ -559,7 +559,7 @@ def setup_gradio(): verbose_training = gr.Checkbox(label="Verbose Console Output", value=True) with gr.Row(): - training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) + training_keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1) training_gpu_count = gr.Number(label="GPUs", value=get_device_count()) with gr.Row(): start_training_button = gr.Button(value="Train") @@ -777,7 +777,7 @@ def setup_gradio(): training_configs, verbose_training, training_gpu_count, - training_keep_x_past_datasets, + training_keep_x_past_checkpoints, ], outputs=[ training_output,