From fe8bf7a9d1bf800c06d31c2404e759874c574e9f Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Tue, 7 Mar 2023 20:16:49 +0000
Subject: [PATCH] added helper script to cull short-enough lines from the
 training set into a validation set (if doing validation during training
 yields good results, i'll add it to the web ui)

---
 src/cull_dataset.py | 35 +++++++++++++++++++++++++++++++++++
 src/utils.py        | 18 +++++++++---------
 src/webui.py        |  4 ++--
 3 files changed, 46 insertions(+), 11 deletions(-)
 create mode 100755 src/cull_dataset.py

diff --git a/src/cull_dataset.py b/src/cull_dataset.py
new file mode 100755
index 0000000..0572405
--- /dev/null
+++ b/src/cull_dataset.py
@@ -0,0 +1,35 @@
+import os
+import sys
+
+indir = f'./training/{sys.argv[1]}/'
+cap = int(sys.argv[2])
+
+if not os.path.isdir(indir):
+	raise Exception(f"Invalid directory: {indir}")
+
+if not os.path.exists(f'{indir}/train.txt'):
+	raise Exception(f"Missing dataset: {indir}/train.txt")
+
+with open(f'{indir}/train.txt', 'r', encoding="utf-8") as f:
+	lines = f.readlines()
+
+validation = []
+training = []
+
+for line in lines:
+	split = line.split("|")
+	filename = split[0]
+	text = split[1]
+
+	if len(text) < cap:
+		validation.append(line.strip())
+	else:
+		training.append(line.strip())
+
+with open(f'{indir}/train_culled.txt', 'w', encoding="utf-8") as f:
+	f.write("\n".join(training))
+
+with open(f'{indir}/validation.txt', 'w', encoding="utf-8") as f:
+	f.write("\n".join(validation))
+
+print(f"Culled {len(validation)} lines")
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
index c209b91..1176880 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -605,7 +605,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
 
 # superfluous, but it cleans up some things
 class TrainingState():
-	def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
+	def __init__(self, config_path, keep_x_past_checkpoints=0, start=True, gpus=1):
 		# parse config to get its iteration
 		with open(config_path, 'r') as file:
 			self.config = yaml.safe_load(file)
@@ -664,8 +664,8 @@ class TrainingState():
 		self.loss_milestones = [ 1.0, 0.15, 0.05 ]
 
 		self.load_losses()
-		if keep_x_past_datasets > 0:
-			self.cleanup_old(keep=keep_x_past_datasets)
+		if keep_x_past_checkpoints > 0:
+			self.cleanup_old(keep=keep_x_past_checkpoints)
 		if start:
 			self.spawn_process(config_path=config_path, gpus=gpus)
 
@@ -772,7 +772,7 @@ class TrainingState():
 			print("Removing", path)
 			os.remove(path)
 
-	def parse(self, line, verbose=False, keep_x_past_datasets=0, buffer_size=8, progress=None ):
+	def parse(self, line, verbose=False, keep_x_past_checkpoints=0, buffer_size=8, progress=None ):
 		self.buffer.append(f'{line}')
 
 		should_return = False
@@ -830,7 +830,7 @@ class TrainingState():
 				print(f'{"{:.3f}".format(percent*100)}% {message}')
 				self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
 
-				self.cleanup_old(keep=keep_x_past_datasets)
+				self.cleanup_old(keep=keep_x_past_checkpoints)
 
 			if line.find('%|') > 0:
 				match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
@@ -986,7 +986,7 @@ class TrainingState():
 			message,
 		)
 
-def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
+def run_training(config_path, verbose=False, gpus=1, keep_x_past_checkpoints=0, progress=gr.Progress(track_tqdm=True)):
 	global training_state
 	if training_state and training_state.process:
 		return "Training already in progress"
@@ -1008,13 +1008,13 @@ def run_training(config_path, verbose=False, gpus=1, keep_x_past_datasets=0, pro
 	unload_whisper()
 	unload_voicefixer()
 
-	training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
+	training_state = TrainingState(config_path=config_path, keep_x_past_checkpoints=keep_x_past_checkpoints, gpus=gpus)
 
 	for line in iter(training_state.process.stdout.readline, ""):
 		if training_state.killed:
 			return
 
-		result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_datasets=keep_x_past_datasets, progress=progress )
+		result, percent, message = training_state.parse( line=line, verbose=verbose, keep_x_past_checkpoints=keep_x_past_checkpoints, progress=progress )
 		print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
 		if result:
 			yield result
@@ -1164,7 +1164,7 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
 
 		for line in parsed_list:
 			match = re.findall(r"^(.+?)_\d+\.wav$", line.split("|")[0])
-			print(match)
+
 			if match is None or len(match) == 0:
 				continue
 			
diff --git a/src/webui.py b/src/webui.py
index abc73b3..4cedadb 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -559,7 +559,7 @@ def setup_gradio():
 						verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
 						
 						with gr.Row():
-							training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
+							training_keep_x_past_checkpoints = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
 							training_gpu_count = gr.Number(label="GPUs", value=get_device_count())
 						with gr.Row():
 							start_training_button = gr.Button(value="Train")
@@ -777,7 +777,7 @@ def setup_gradio():
 				training_configs,
 				verbose_training,
 				training_gpu_count,
-				training_keep_x_past_datasets,
+				training_keep_x_past_checkpoints,
 			],
 			outputs=[
 				training_output,