From 5be14abc216660171cb043d409a0a0cfef61e4c9 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Sun, 5 Mar 2023 23:55:27 +0000
Subject: [PATCH] UI cleanup, actually fix syncing the epoch counter (I hope),
 treat an auto-suggested voice chunk size of 0 as "split on the average clip
 duration", and signal when a NaN info value is detected (there are some
 safeties in the training, but a NaN will inevitably ruin the model)

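The emotion radio gains a "None" option (now the default) that skips
prepending the "[I am really ...,]" prompt entirely, and the custom
emotion textbox is only shown while "Custom" is selected.

Setting `autocalculate-voice-chunk-duration-size` to 0 (also the new
default) derives the suggested "Voice Chunks" value from the average
clip duration instead of a fixed chunk size. A rough sketch of the
math, with hypothetical numbers:

    # e.g. 10 clips totalling 50 seconds: 50 / 10 = 5 suggested chunks,
    # clamped to at least 1 to stay within the slider's range
    return max(1, int(total_duration / total)) if total > 0 else 1

A NaN in the training INFO output no longer spams the buffer; it sets a
flag once and prefixes every later status message with "[!NaN DETECTED!]"
(the training loop has some safeties, but the model is likely dead by
that point). The microphone input is hidden unless the "microphone"
voice is selected, the Pause Size and diffusion sampler controls moved
under the experimental settings column (with DDIM as the new default
sampler), and the audio-metadata widgets in the Utilities tab now
toggle as a single column.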
---
 src/utils.py | 23 +++++++++++++++--------
 src/webui.py | 47 ++++++++++++++++++++++++++++-------------------
 2 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 8424c6c..41dd15f 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -233,7 +233,7 @@ def generate(
 		if emotion == "Custom":
 			if prompt and prompt.strip() != "":
 				cut_text = f"[{prompt},] {cut_text}"
-		else:
+		elif emotion != "None":
 			cut_text = f"[I am really {emotion.lower()},] {cut_text}"
 
 		progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]'
@@ -464,14 +464,21 @@ def update_baseline_for_latents_chunks( voice ):
 		return 1
 
 	files = os.listdir(path)
+
+	total = 0
 	total_duration = 0
+
 	for file in files:
 		if file[-4:] != ".wav":
 			continue
+
 		metadata = torchaudio.info(f'{path}/{file}')
 		duration = metadata.num_channels * metadata.num_frames / metadata.sample_rate
 		total_duration += duration
+		total += 1
 
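+	# a chunk-duration size of 0 means "auto": derive the suggestion from the average clip duration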
+	if args.autocalculate_voice_chunk_duration_size == 0:
+		return max(1, int(total_duration / total)) if total > 0 else 1
 	return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1
 
 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
@@ -550,6 +557,8 @@ class TrainingState():
 		self.eta = "?"
 		self.eta_hhmmss = "?"
 
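+		# set when a NaN appears in the training INFO output; prefixed onto later status messages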
+		self.nan_detected = False
+
 		self.last_info_check_at = 0
 		self.statistics = []
 		self.losses = []
@@ -701,13 +710,10 @@ class TrainingState():
 				info_line = line.split("INFO:")[-1]
 				# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
 				if ': nan' in info_line:
-					should_return = True
-
-					print("! NAN DETECTED !")
-					self.buffer.append("! NAN DETECTED !")
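+					# flag it once; the warning is surfaced on the status line instead of spamming the buffer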
+					self.nan_detected = True
 
 				# easily rip out our stats...
-				match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
+				match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
 				if match and len(match) > 0:
 					for k, v in match:
 						self.info[k] = float(v.replace(",", ""))
@@ -862,6 +868,8 @@ class TrainingState():
 			self.metrics['loss'] = ", ".join(self.metrics['loss'])
 
 			message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
+			if self.nan_detected:
+				message = f"[!NaN DETECTED!] {message}"
 
 			if message:
 				percent = self.it / float(self.its) # self.epoch / float(self.epochs)
@@ -965,7 +973,6 @@ def stop_training():
 	try:
 		children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
 	except Exception as e:
-		print(e)
 		pass
 
 	training_state.process.stdout.close()
@@ -1419,7 +1426,7 @@ def setup_args():
 		'prune-nonfinal-outputs': True,
 		'use-bigvgan-vocoder': True,
 		'concurrency-count': 2,
-		'autocalculate-voice-chunk-duration-size': 10,
+		'autocalculate-voice-chunk-duration-size': 0,
 		'output-sample-rate': 44100,
 		'output-volume': 1,
 		
diff --git a/src/webui.py b/src/webui.py
index 3ca76d0..f9b06ca 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -180,9 +180,9 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
 
 	return (
 		gr.update(value=j, visible=j is not None),
-		gr.update(visible=j is not None),
 		gr.update(value=latents, visible=latents is not None),
-		None if j is None else j['voice']
+		None if j is None else j['voice'],
+		gr.update(visible=j is not None),
 	)
 
 def prepare_dataset_proxy( voice, language, progress=gr.Progress(track_tqdm=True) ):
@@ -378,15 +378,15 @@ def setup_gradio():
 		with gr.Tab("Generate"):
 			with gr.Row():
 				with gr.Column():
-					text = gr.Textbox(lines=4, label="Prompt")
+					text = gr.Textbox(lines=4, label="Input Prompt")
 			with gr.Row():
 				with gr.Column():
 					delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
 
-					emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom"], value="Custom", label="Emotion", type="value", interactive=True )
-					prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
+					emotion = gr.Radio( ["Happy", "Sad", "Angry", "Disgusted", "Arrogant", "Custom", "None"], value="None", label="Emotion", type="value", interactive=True )
+					prompt = gr.Textbox(lines=1, label="Custom Emotion")
 					voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
-					mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
+					mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False )
 					voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=128, value=1, step=1)
 					with gr.Row():
 						refresh_voices = gr.Button(value="Refresh Voice List")
@@ -397,6 +397,11 @@ def setup_gradio():
 						inputs=voice,
 						outputs=voice_latents_chunks
 					)
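+					# show the microphone recorder only while the "microphone" voice is selected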
+					voice.change(
+						fn=lambda value: gr.update(visible=value == "microphone"),
+						inputs=voice,
+						outputs=mic_audio,
+					)
 				with gr.Column():
 					candidates = gr.Slider(value=1, minimum=1, maximum=6, step=1, label="Candidates")
 					seed = gr.Number(value=0, precision=0, label="Seed")
@@ -406,16 +411,17 @@ def setup_gradio():
 					diffusion_iterations = gr.Slider(value=128, minimum=0, maximum=512, step=1, label="Iterations")
 
 					temperature = gr.Slider(value=0.2, minimum=0, maximum=1, step=0.1, label="Temperature")
-					breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size")
-					diffusion_sampler = gr.Radio(
-						["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"],
-						value="P", label="Diffusion Samplers", type="value" )
 					show_experimental_settings = gr.Checkbox(label="Show Experimental Settings")
 					reset_generation_settings_button = gr.Button(value="Reset to Default")
 				with gr.Column(visible=False) as col:
 					experimental_column = col
 
 					experimental_checkboxes = gr.CheckboxGroup(["Half Precision", "Conditioning-Free"], value=["Conditioning-Free"], label="Experimental Flags")
+					breathing_room = gr.Slider(value=8, minimum=1, maximum=32, step=1, label="Pause Size")
+					diffusion_sampler = gr.Radio(
+						["P", "DDIM"], # + ["K_Euler_A", "DPM++2M"],
+						value="DDIM", label="Diffusion Samplers", type="value"
+					)
 					cvvp_weight = gr.Slider(value=0, minimum=0, maximum=1, label="CVVP Weight")
 					top_p = gr.Slider(value=0.8, minimum=0, maximum=1, label="Top P")
 					diffusion_temperature = gr.Slider(value=1.0, minimum=0, maximum=1, label="Diffusion Temperature")
@@ -460,10 +466,12 @@ def setup_gradio():
 					audio_in = gr.Files(type="file", label="Audio Input", file_types=["audio"])
 					import_voice_name = gr.Textbox(label="Voice Name")
 					import_voice_button = gr.Button(value="Import Voice")
-				with gr.Column():
-					metadata_out = gr.JSON(label="Audio Metadata", visible=False)
-					copy_button = gr.Button(value="Copy Settings", visible=False)
-					latents_out = gr.File(type="binary", label="Voice Latents", visible=False)
+				with gr.Column(visible=False) as col:
+					utilities_metadata_column = col
+
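+					# reveal the whole metadata column at once instead of toggling each widget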
+					metadata_out = gr.JSON(label="Audio Metadata")
+					copy_button = gr.Button(value="Copy Settings")
+					latents_out = gr.File(type="binary", label="Voice Latents")
 		with gr.Tab("Training"):
 			with gr.Tab("Prepare Dataset"):
 				with gr.Row():
@@ -662,9 +670,9 @@ def setup_gradio():
 			inputs=audio_in,
 			outputs=[
 				metadata_out,
-				copy_button,
 				latents_out,
-				import_voice_name
+				import_voice_name,
+				utilities_metadata_column,
 			]
 		)
 
@@ -697,9 +705,10 @@ def setup_gradio():
 			outputs=voice,
 		)
 		
-		prompt.change(fn=lambda value: gr.update(value="Custom"),
-			inputs=prompt,
-			outputs=emotion
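+		# the custom emotion textbox is only visible while "Custom" is selected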
+		emotion.change(
+			fn=lambda value: gr.update(visible=value == "Custom"),
+			inputs=emotion,
+			outputs=prompt
 		)
 		mic_audio.change(fn=lambda value: gr.update(value="microphone"),
 			inputs=mic_audio,